github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/trigger.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Support for Apache Beam triggers.
    19  
    20  Triggers control when in processing time windows get emitted.
    21  """
    22  
    23  # pytype: skip-file
    24  
    25  import collections
    26  import copy
    27  import logging
    28  import numbers
    29  from abc import ABCMeta
    30  from abc import abstractmethod
    31  from collections import abc as collections_abc  # ambiguty with direct abc
    32  from enum import Flag
    33  from enum import auto
    34  from itertools import zip_longest
    35  
    36  from apache_beam.coders import coder_impl
    37  from apache_beam.coders import observable
    38  from apache_beam.portability.api import beam_runner_api_pb2
    39  from apache_beam.transforms import combiners
    40  from apache_beam.transforms import core
    41  from apache_beam.transforms.timeutil import TimeDomain
    42  from apache_beam.transforms.window import GlobalWindow
    43  from apache_beam.transforms.window import GlobalWindows
    44  from apache_beam.transforms.window import TimestampCombiner
    45  from apache_beam.transforms.window import WindowedValue
    46  from apache_beam.transforms.window import WindowFn
    47  from apache_beam.utils import windowed_value
    48  from apache_beam.utils.timestamp import MAX_TIMESTAMP
    49  from apache_beam.utils.timestamp import MIN_TIMESTAMP
    50  from apache_beam.utils.timestamp import TIME_GRANULARITY
    51  
    52  __all__ = [
    53      'AccumulationMode',
    54      'TriggerFn',
    55      'DefaultTrigger',
    56      'AfterWatermark',
    57      'AfterProcessingTime',
    58      'AfterCount',
    59      'Repeatedly',
    60      'AfterAny',
    61      'AfterAll',
    62      'AfterEach',
    63      'OrFinally',
    64  ]
    65  
    66  _LOGGER = logging.getLogger(__name__)
    67  
    68  
    69  class AccumulationMode(object):
    70    """Controls what to do with data when a trigger fires multiple times."""
    71    DISCARDING = beam_runner_api_pb2.AccumulationMode.DISCARDING
    72    ACCUMULATING = beam_runner_api_pb2.AccumulationMode.ACCUMULATING
    73    # TODO(robertwb): Provide retractions of previous outputs.
    74    # RETRACTING = 3
    75  
    76  
    77  class _StateTag(metaclass=ABCMeta):
    78    """An identifier used to store and retrieve typed, combinable state.
    79  
    80    The given tag must be unique for this step."""
    81    def __init__(self, tag):
    82      self.tag = tag
    83  
    84  
    85  class _ReadModifyWriteStateTag(_StateTag):
    86    """StateTag pointing to an element."""
    87    def __repr__(self):
    88      return 'ValueStateTag(%s)' % (self.tag)
    89  
    90    def with_prefix(self, prefix):
    91      return _ReadModifyWriteStateTag(prefix + self.tag)
    92  
    93  
    94  class _SetStateTag(_StateTag):
    95    """StateTag pointing to an element."""
    96    def __repr__(self):
    97      return 'SetStateTag({tag})'.format(tag=self.tag)
    98  
    99    def with_prefix(self, prefix):
   100      return _SetStateTag(prefix + self.tag)
   101  
   102  
   103  class _CombiningValueStateTag(_StateTag):
   104    """StateTag pointing to an element, accumulated with a combiner.
   105  
   106    The given tag must be unique for this step. The given CombineFn will be
   107    applied (possibly incrementally and eagerly) when adding elements."""
   108  
   109    # TODO(robertwb): Also store the coder (perhaps extracted from the combine_fn)
   110    def __init__(self, tag, combine_fn):
   111      super().__init__(tag)
   112      if not combine_fn:
   113        raise ValueError('combine_fn must be specified.')
   114      if not isinstance(combine_fn, core.CombineFn):
   115        combine_fn = core.CombineFn.from_callable(combine_fn)
   116      self.combine_fn = combine_fn
   117  
   118    def __repr__(self):
   119      return 'CombiningValueStateTag(%s, %s)' % (self.tag, self.combine_fn)
   120  
   121    def with_prefix(self, prefix):
   122      return _CombiningValueStateTag(prefix + self.tag, self.combine_fn)
   123  
   124    def without_extraction(self):
   125      class NoExtractionCombineFn(core.CombineFn):
   126        setup = self.combine_fn.setup
   127        create_accumulator = self.combine_fn.create_accumulator
   128        add_input = self.combine_fn.add_input
   129        merge_accumulators = self.combine_fn.merge_accumulators
   130        compact = self.combine_fn.compact
   131        extract_output = staticmethod(lambda x: x)
   132        teardown = self.combine_fn.teardown
   133  
   134      return _CombiningValueStateTag(self.tag, NoExtractionCombineFn())
   135  
   136  
   137  class _ListStateTag(_StateTag):
   138    """StateTag pointing to a list of elements."""
   139    def __repr__(self):
   140      return 'ListStateTag(%s)' % self.tag
   141  
   142    def with_prefix(self, prefix):
   143      return _ListStateTag(prefix + self.tag)
   144  
   145  
   146  class _WatermarkHoldStateTag(_StateTag):
   147    def __init__(self, tag, timestamp_combiner_impl):
   148      super().__init__(tag)
   149      self.timestamp_combiner_impl = timestamp_combiner_impl
   150  
   151    def __repr__(self):
   152      return 'WatermarkHoldStateTag(%s, %s)' % (
   153          self.tag, self.timestamp_combiner_impl)
   154  
   155    def with_prefix(self, prefix):
   156      return _WatermarkHoldStateTag(
   157          prefix + self.tag, self.timestamp_combiner_impl)
   158  
   159  
   160  class DataLossReason(Flag):
   161    """Enum defining potential reasons that a trigger may cause data loss.
   162  
   163    These flags should only cover when the trigger is the cause, though windowing
   164    can be taken into account. For instance, AfterWatermark may not flag itself
   165    as finishing if the windowing doesn't allow lateness.
   166    """
   167  
   168    # Trigger will never be the source of data loss.
   169    NO_POTENTIAL_LOSS = 0
   170  
   171    # Trigger may finish. In this case, data that comes in after the trigger may
   172    # be lost. Example: AfterCount(1) will stop firing after the first element.
   173    MAY_FINISH = auto()
   174  
   175    # Deprecated: Beam will emit buffered data at GC time. Any other behavior
   176    # should be treated as a bug with the runner used.
   177    CONDITION_NOT_GUARANTEED = auto()
   178  
   179  
   180  # Convenience functions for checking if a flag is included. Each is equivalent
   181  # to `reason & flag == flag`
   182  
   183  
   184  def _IncludesMayFinish(reason):
   185    # type: (DataLossReason) -> bool
   186    return reason & DataLossReason.MAY_FINISH == DataLossReason.MAY_FINISH
   187  
   188  
   189  # pylint: disable=unused-argument
   190  # TODO(robertwb): Provisional API, Java likely to change as well.
   191  class TriggerFn(metaclass=ABCMeta):
   192    """A TriggerFn determines when window (panes) are emitted.
   193  
   194    See https://beam.apache.org/documentation/programming-guide/#triggers
   195    """
   196    @abstractmethod
   197    def on_element(self, element, window, context):
   198      """Called when a new element arrives in a window.
   199  
   200      Args:
   201        element: the element being added
   202        window: the window to which the element is being added
   203        context: a context (e.g. a TriggerContext instance) for managing state
   204            and setting timers
   205      """
   206      pass
   207  
   208    @abstractmethod
   209    def on_merge(self, to_be_merged, merge_result, context):
   210      """Called when multiple windows are merged.
   211  
   212      Args:
   213        to_be_merged: the set of windows to be merged
   214        merge_result: the window into which the windows are being merged
   215        context: a context (e.g. a TriggerContext instance) for managing state
   216            and setting timers
   217      """
   218      pass
   219  
   220    @abstractmethod
   221    def should_fire(self, time_domain, timestamp, window, context):
   222      """Whether this trigger should cause the window to fire.
   223  
   224      Args:
   225        time_domain: WATERMARK for event-time timers and REAL_TIME for
   226            processing-time timers.
   227        timestamp: for time_domain WATERMARK, it represents the
   228            watermark: (a lower bound on) the watermark of the system
   229            and for time_domain REAL_TIME, it represents the
   230            trigger: timestamp of the processing-time timer.
   231        window: the window whose trigger is being considered
   232        context: a context (e.g. a TriggerContext instance) for managing state
   233            and setting timers
   234  
   235      Returns:
   236        whether this trigger should cause a firing
   237      """
   238      pass
   239  
   240    @abstractmethod
   241    def has_ontime_pane(self):
   242      """Whether this trigger creates an empty pane even if there are no elements.
   243  
   244      Returns:
   245        True if this trigger guarantees that there will always be an ON_TIME pane
   246        even if there are no elements in that pane.
   247      """
   248      pass
   249  
   250    @abstractmethod
   251    def on_fire(self, watermark, window, context):
   252      """Called when a trigger actually fires.
   253  
   254      Args:
   255        watermark: (a lower bound on) the watermark of the system
   256        window: the window whose trigger is being fired
   257        context: a context (e.g. a TriggerContext instance) for managing state
   258            and setting timers
   259  
   260      Returns:
   261        whether this trigger is finished
   262      """
   263      pass
   264  
   265    @abstractmethod
   266    def reset(self, window, context):
   267      """Clear any state and timers used by this TriggerFn."""
   268      pass
   269  
   270    def may_lose_data(self, unused_windowing):
   271      # type: (core.Windowing) -> DataLossReason
   272  
   273      """Returns whether or not this trigger could cause data loss.
   274  
   275      A trigger can cause data loss in the following scenarios:
   276  
   277          * The trigger has a chance to finish. For instance, AfterWatermark()
   278            without a late trigger would cause all late data to be lost. This
   279            scenario is only accounted for if the windowing strategy allows
   280            late data. Otherwise, the trigger is not responsible for the data
   281            loss.
   282  
   283      Note that this only returns the potential for loss. It does not mean that
   284      there will be data loss. It also only accounts for loss related to the
   285      trigger, not other potential causes.
   286  
   287      Args:
   288        windowing: The Windowing that this trigger belongs to. It does not need
   289          to be the top-level trigger.
   290  
   291      Returns:
   292        The DataLossReason. If there is no potential loss,
   293          DataLossReason.NO_POTENTIAL_LOSS is returned. Otherwise, all the
   294          potential reasons are returned as a single value.
   295      """
   296      # For backwards compatibility's sake, we're assuming the trigger is safe.
   297      return DataLossReason.NO_POTENTIAL_LOSS
   298  
   299  
   300  # pylint: enable=unused-argument
   301  
   302    @staticmethod
   303    def from_runner_api(proto, context):
   304      return {
   305          'after_all': AfterAll,
   306          'after_any': AfterAny,
   307          'after_each': AfterEach,
   308          'after_end_of_window': AfterWatermark,
   309          'after_processing_time': AfterProcessingTime,
   310          # after_processing_time, after_synchronized_processing_time
   311          'always': Always,
   312          'default': DefaultTrigger,
   313          'element_count': AfterCount,
   314          'never': _Never,
   315          'or_finally': OrFinally,
   316          'repeat': Repeatedly,
   317      }[proto.WhichOneof('trigger')].from_runner_api(proto, context)
   318  
   319    @abstractmethod
   320    def to_runner_api(self, unused_context):
   321      pass
   322  
   323  
   324  class DefaultTrigger(TriggerFn):
   325    """Semantically Repeatedly(AfterWatermark()), but more optimized."""
   326    def __init__(self):
   327      pass
   328  
   329    def __repr__(self):
   330      return 'DefaultTrigger()'
   331  
   332    def on_element(self, element, window, context):
   333      context.set_timer(str(window), TimeDomain.WATERMARK, window.end)
   334  
   335    def on_merge(self, to_be_merged, merge_result, context):
   336      for window in to_be_merged:
   337        context.clear_timer(str(window), TimeDomain.WATERMARK)
   338  
   339    def should_fire(self, time_domain, watermark, window, context):
   340      if watermark >= window.end:
   341        # Explicitly clear the timer so that late elements are not emitted again
   342        # when the timer is fired.
   343        context.clear_timer(str(window), TimeDomain.WATERMARK)
   344      return watermark >= window.end
   345  
   346    def on_fire(self, watermark, window, context):
   347      return False
   348  
   349    def reset(self, window, context):
   350      context.clear_timer(str(window), TimeDomain.WATERMARK)
   351  
   352    def may_lose_data(self, unused_windowing):
   353      return DataLossReason.NO_POTENTIAL_LOSS
   354  
   355    def __eq__(self, other):
   356      return type(self) == type(other)
   357  
   358    def __hash__(self):
   359      return hash(type(self))
   360  
   361    @staticmethod
   362    def from_runner_api(proto, context):
   363      return DefaultTrigger()
   364  
   365    def to_runner_api(self, unused_context):
   366      return beam_runner_api_pb2.Trigger(
   367          default=beam_runner_api_pb2.Trigger.Default())
   368  
   369    def has_ontime_pane(self):
   370      return True
   371  
   372  
   373  class AfterProcessingTime(TriggerFn):
   374    """Fire exactly once after a specified delay from processing time."""
   375  
   376    STATE_TAG = _SetStateTag('has_timer')
   377  
   378    def __init__(self, delay=0):
   379      """Initialize a processing time trigger with a delay in seconds."""
   380      self.delay = delay
   381  
   382    def __repr__(self):
   383      return 'AfterProcessingTime(delay=%d)' % self.delay
   384  
   385    def on_element(self, element, window, context):
   386      if not context.get_state(self.STATE_TAG):
   387        context.set_timer(
   388            '', TimeDomain.REAL_TIME, context.get_current_time() + self.delay)
   389      context.add_state(self.STATE_TAG, True)
   390  
   391    def on_merge(self, to_be_merged, merge_result, context):
   392      # timers will be kept through merging
   393      pass
   394  
   395    def should_fire(self, time_domain, timestamp, window, context):
   396      if time_domain == TimeDomain.REAL_TIME:
   397        return True
   398  
   399    def on_fire(self, timestamp, window, context):
   400      return True
   401  
   402    def reset(self, window, context):
   403      context.clear_state(self.STATE_TAG)
   404  
   405    def may_lose_data(self, unused_windowing):
   406      """AfterProcessingTime may finish."""
   407      return DataLossReason.MAY_FINISH
   408  
   409    @staticmethod
   410    def from_runner_api(proto, context):
   411      return AfterProcessingTime(
   412          delay=(
   413              proto.after_processing_time.timestamp_transforms[0].delay.
   414              delay_millis) // 1000)
   415  
   416    def to_runner_api(self, context):
   417      delay_proto = beam_runner_api_pb2.TimestampTransform(
   418          delay=beam_runner_api_pb2.TimestampTransform.Delay(
   419              delay_millis=self.delay * 1000))
   420      return beam_runner_api_pb2.Trigger(
   421          after_processing_time=beam_runner_api_pb2.Trigger.AfterProcessingTime(
   422              timestamp_transforms=[delay_proto]))
   423  
   424    def has_ontime_pane(self):
   425      return False
   426  
   427  
   428  class Always(TriggerFn):
   429    """Repeatedly invoke the given trigger, never finishing."""
   430    def __init__(self):
   431      pass
   432  
   433    def __repr__(self):
   434      return 'Always'
   435  
   436    def __eq__(self, other):
   437      return type(self) == type(other)
   438  
   439    def __hash__(self):
   440      return 1
   441  
   442    def on_element(self, element, window, context):
   443      pass
   444  
   445    def on_merge(self, to_be_merged, merge_result, context):
   446      pass
   447  
   448    def has_ontime_pane(self):
   449      return False
   450  
   451    def reset(self, window, context):
   452      pass
   453  
   454    def should_fire(self, time_domain, watermark, window, context):
   455      return True
   456  
   457    def on_fire(self, watermark, window, context):
   458      return False
   459  
   460    def may_lose_data(self, unused_windowing):
   461      """No potential loss, since the trigger always fires."""
   462      return DataLossReason.NO_POTENTIAL_LOSS
   463  
   464    @staticmethod
   465    def from_runner_api(proto, context):
   466      return Always()
   467  
   468    def to_runner_api(self, context):
   469      return beam_runner_api_pb2.Trigger(
   470          always=beam_runner_api_pb2.Trigger.Always())
   471  
   472  
   473  class _Never(TriggerFn):
   474    """A trigger that never fires.
   475  
   476    Data may still be released at window closing.
   477    """
   478    def __init__(self):
   479      pass
   480  
   481    def __repr__(self):
   482      return 'Never'
   483  
   484    def __eq__(self, other):
   485      return type(self) == type(other)
   486  
   487    def __hash__(self):
   488      return hash(type(self))
   489  
   490    def on_element(self, element, window, context):
   491      pass
   492  
   493    def on_merge(self, to_be_merged, merge_result, context):
   494      pass
   495  
   496    def has_ontime_pane(self):
   497      False
   498  
   499    def reset(self, window, context):
   500      pass
   501  
   502    def should_fire(self, time_domain, watermark, window, context):
   503      return False
   504  
   505    def on_fire(self, watermark, window, context):
   506      return True
   507  
   508    def may_lose_data(self, unused_windowing):
   509      """No potential data loss.
   510  
   511      Though Never doesn't explicitly trigger, it still collects data on
   512      windowing closing.
   513      """
   514      return DataLossReason.NO_POTENTIAL_LOSS
   515  
   516    @staticmethod
   517    def from_runner_api(proto, context):
   518      return _Never()
   519  
   520    def to_runner_api(self, context):
   521      return beam_runner_api_pb2.Trigger(
   522          never=beam_runner_api_pb2.Trigger.Never())
   523  
   524  
   525  class AfterWatermark(TriggerFn):
   526    """Fire exactly once when the watermark passes the end of the window.
   527  
   528    Args:
   529        early: if not None, a speculative trigger to repeatedly evaluate before
   530          the watermark passes the end of the window
   531        late: if not None, a speculative trigger to repeatedly evaluate after
   532          the watermark passes the end of the window
   533    """
   534    LATE_TAG = _CombiningValueStateTag('is_late', any)
   535  
   536    def __init__(self, early=None, late=None):
   537      # TODO(zhoufek): Maybe don't wrap early/late if they are already Repeatedly
   538      self.early = Repeatedly(early) if early else None
   539      self.late = Repeatedly(late) if late else None
   540  
   541    def __repr__(self):
   542      qualifiers = []
   543      if self.early:
   544        qualifiers.append('early=%s' % self.early.underlying)
   545      if self.late:
   546        qualifiers.append('late=%s' % self.late.underlying)
   547      return 'AfterWatermark(%s)' % ', '.join(qualifiers)
   548  
   549    def is_late(self, context):
   550      return self.late and context.get_state(self.LATE_TAG)
   551  
   552    def on_element(self, element, window, context):
   553      if self.is_late(context):
   554        self.late.on_element(element, window, NestedContext(context, 'late'))
   555      else:
   556        context.set_timer('', TimeDomain.WATERMARK, window.end)
   557        if self.early:
   558          self.early.on_element(element, window, NestedContext(context, 'early'))
   559  
   560    def on_merge(self, to_be_merged, merge_result, context):
   561      # TODO(robertwb): Figure out whether the 'rewind' semantics could be used
   562      # here.
   563      if self.is_late(context):
   564        self.late.on_merge(
   565            to_be_merged, merge_result, NestedContext(context, 'late'))
   566      else:
   567        # Note: Timer clearing solely an optimization.
   568        for window in to_be_merged:
   569          if window.end != merge_result.end:
   570            context.clear_timer('', TimeDomain.WATERMARK)
   571        if self.early:
   572          self.early.on_merge(
   573              to_be_merged, merge_result, NestedContext(context, 'early'))
   574  
   575    def should_fire(self, time_domain, watermark, window, context):
   576      if self.is_late(context):
   577        return self.late.should_fire(
   578            time_domain, watermark, window, NestedContext(context, 'late'))
   579      elif watermark >= window.end:
   580        # Explicitly clear the timer so that late elements are not emitted again
   581        # when the timer is fired.
   582        context.clear_timer('', TimeDomain.WATERMARK)
   583        return True
   584      elif self.early:
   585        return self.early.should_fire(
   586            time_domain, watermark, window, NestedContext(context, 'early'))
   587      return False
   588  
   589    def on_fire(self, watermark, window, context):
   590      if self.is_late(context):
   591        return self.late.on_fire(
   592            watermark, window, NestedContext(context, 'late'))
   593      elif watermark >= window.end:
   594        context.add_state(self.LATE_TAG, True)
   595        return not self.late
   596      elif self.early:
   597        self.early.on_fire(watermark, window, NestedContext(context, 'early'))
   598        return False
   599  
   600    def reset(self, window, context):
   601      if self.late:
   602        context.clear_state(self.LATE_TAG)
   603      if self.early:
   604        self.early.reset(window, NestedContext(context, 'early'))
   605      if self.late:
   606        self.late.reset(window, NestedContext(context, 'late'))
   607  
   608    def may_lose_data(self, windowing):
   609      """May cause data loss if lateness allowed and no late trigger set."""
   610      if windowing.allowed_lateness == 0:
   611        return DataLossReason.NO_POTENTIAL_LOSS
   612      if self.late is None:
   613        return DataLossReason.MAY_FINISH
   614      return self.late.may_lose_data(windowing)
   615  
   616    def __eq__(self, other):
   617      return (
   618          type(self) == type(other) and self.early == other.early and
   619          self.late == other.late)
   620  
   621    def __hash__(self):
   622      return hash((type(self), self.early, self.late))
   623  
   624    @staticmethod
   625    def from_runner_api(proto, context):
   626      return AfterWatermark(
   627          early=TriggerFn.from_runner_api(
   628              proto.after_end_of_window.early_firings, context)
   629          if proto.after_end_of_window.HasField('early_firings') else None,
   630          late=TriggerFn.from_runner_api(
   631              proto.after_end_of_window.late_firings, context)
   632          if proto.after_end_of_window.HasField('late_firings') else None)
   633  
   634    def to_runner_api(self, context):
   635      early_proto = self.early.underlying.to_runner_api(
   636          context) if self.early else None
   637      late_proto = self.late.underlying.to_runner_api(
   638          context) if self.late else None
   639      return beam_runner_api_pb2.Trigger(
   640          after_end_of_window=beam_runner_api_pb2.Trigger.AfterEndOfWindow(
   641              early_firings=early_proto, late_firings=late_proto))
   642  
   643    def has_ontime_pane(self):
   644      return True
   645  
   646  
   647  class AfterCount(TriggerFn):
   648    """Fire when there are at least count elements in this window pane."""
   649  
   650    COUNT_TAG = _CombiningValueStateTag('count', combiners.CountCombineFn())
   651  
   652    def __init__(self, count):
   653      if not isinstance(count, numbers.Integral) or count < 1:
   654        raise ValueError("count (%d) must be a positive integer." % count)
   655      self.count = count
   656  
   657    def __repr__(self):
   658      return 'AfterCount(%s)' % self.count
   659  
   660    def __eq__(self, other):
   661      return type(self) == type(other) and self.count == other.count
   662  
   663    def __hash__(self):
   664      return hash(self.count)
   665  
   666    def on_element(self, element, window, context):
   667      context.add_state(self.COUNT_TAG, 1)
   668  
   669    def on_merge(self, to_be_merged, merge_result, context):
   670      # states automatically merged
   671      pass
   672  
   673    def should_fire(self, time_domain, watermark, window, context):
   674      return context.get_state(self.COUNT_TAG) >= self.count
   675  
   676    def on_fire(self, watermark, window, context):
   677      return True
   678  
   679    def reset(self, window, context):
   680      context.clear_state(self.COUNT_TAG)
   681  
   682    def may_lose_data(self, unused_windowing):
   683      """AfterCount may finish."""
   684      return DataLossReason.MAY_FINISH
   685  
   686    @staticmethod
   687    def from_runner_api(proto, unused_context):
   688      return AfterCount(proto.element_count.element_count)
   689  
   690    def to_runner_api(self, unused_context):
   691      return beam_runner_api_pb2.Trigger(
   692          element_count=beam_runner_api_pb2.Trigger.ElementCount(
   693              element_count=self.count))
   694  
   695    def has_ontime_pane(self):
   696      return False
   697  
   698  
   699  class Repeatedly(TriggerFn):
   700    """Repeatedly invoke the given trigger, never finishing."""
   701    def __init__(self, underlying):
   702      self.underlying = underlying
   703  
   704    def __repr__(self):
   705      return 'Repeatedly(%s)' % self.underlying
   706  
   707    def __eq__(self, other):
   708      return type(self) == type(other) and self.underlying == other.underlying
   709  
   710    def __hash__(self):
   711      return hash(self.underlying)
   712  
   713    def on_element(self, element, window, context):
   714      self.underlying.on_element(element, window, context)
   715  
   716    def on_merge(self, to_be_merged, merge_result, context):
   717      self.underlying.on_merge(to_be_merged, merge_result, context)
   718  
   719    def should_fire(self, time_domain, watermark, window, context):
   720      return self.underlying.should_fire(time_domain, watermark, window, context)
   721  
   722    def on_fire(self, watermark, window, context):
   723      if self.underlying.on_fire(watermark, window, context):
   724        self.underlying.reset(window, context)
   725      return False
   726  
   727    def reset(self, window, context):
   728      self.underlying.reset(window, context)
   729  
   730    def may_lose_data(self, windowing):
   731      """Repeatedly will run in a loop and pick up whatever is left at GC."""
   732      return DataLossReason.NO_POTENTIAL_LOSS
   733  
   734    @staticmethod
   735    def from_runner_api(proto, context):
   736      return Repeatedly(
   737          TriggerFn.from_runner_api(proto.repeat.subtrigger, context))
   738  
   739    def to_runner_api(self, context):
   740      return beam_runner_api_pb2.Trigger(
   741          repeat=beam_runner_api_pb2.Trigger.Repeat(
   742              subtrigger=self.underlying.to_runner_api(context)))
   743  
   744    def has_ontime_pane(self):
   745      return self.underlying.has_ontime_pane()
   746  
   747  
   748  class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta):
   749    def __init__(self, *triggers):
   750      self.triggers = triggers
   751  
   752    def __repr__(self):
   753      return '%s(%s)' % (
   754          self.__class__.__name__, ', '.join(str(t) for t in self.triggers))
   755  
   756    def __eq__(self, other):
   757      return type(self) == type(other) and self.triggers == other.triggers
   758  
   759    def __hash__(self):
   760      return hash(self.triggers)
   761  
   762    @abstractmethod
   763    def combine_op(self, trigger_results):
   764      pass
   765  
   766    def on_element(self, element, window, context):
   767      for ix, trigger in enumerate(self.triggers):
   768        trigger.on_element(element, window, self._sub_context(context, ix))
   769  
   770    def on_merge(self, to_be_merged, merge_result, context):
   771      for ix, trigger in enumerate(self.triggers):
   772        trigger.on_merge(
   773            to_be_merged, merge_result, self._sub_context(context, ix))
   774  
   775    def should_fire(self, time_domain, watermark, window, context):
   776      self._time_domain = time_domain
   777      return self.combine_op(
   778          trigger.should_fire(
   779              time_domain, watermark, window, self._sub_context(context, ix))
   780          for ix,
   781          trigger in enumerate(self.triggers))
   782  
   783    def on_fire(self, watermark, window, context):
   784      finished = []
   785      for ix, trigger in enumerate(self.triggers):
   786        nested_context = self._sub_context(context, ix)
   787        if trigger.should_fire(TimeDomain.WATERMARK,
   788                               watermark,
   789                               window,
   790                               nested_context):
   791          finished.append(trigger.on_fire(watermark, window, nested_context))
   792      return self.combine_op(finished)
   793  
   794    def may_lose_data(self, windowing):
   795      may_finish = self.combine_op(
   796          _IncludesMayFinish(t.may_lose_data(windowing)) for t in self.triggers)
   797      return (
   798          DataLossReason.MAY_FINISH
   799          if may_finish else DataLossReason.NO_POTENTIAL_LOSS)
   800  
   801    def reset(self, window, context):
   802      for ix, trigger in enumerate(self.triggers):
   803        trigger.reset(window, self._sub_context(context, ix))
   804  
   805    @staticmethod
   806    def _sub_context(context, index):
   807      return NestedContext(context, '%d/' % index)
   808  
   809    @staticmethod
   810    def from_runner_api(proto, context):
   811      subtriggers = [
   812          TriggerFn.from_runner_api(subtrigger, context) for subtrigger in
   813          proto.after_all.subtriggers or proto.after_any.subtriggers
   814      ]
   815      if proto.after_all.subtriggers:
   816        return AfterAll(*subtriggers)
   817      else:
   818        return AfterAny(*subtriggers)
   819  
   820    def to_runner_api(self, context):
   821      subtriggers = [
   822          subtrigger.to_runner_api(context) for subtrigger in self.triggers
   823      ]
   824      if self.combine_op == all:
   825        return beam_runner_api_pb2.Trigger(
   826            after_all=beam_runner_api_pb2.Trigger.AfterAll(
   827                subtriggers=subtriggers))
   828      elif self.combine_op == any:
   829        return beam_runner_api_pb2.Trigger(
   830            after_any=beam_runner_api_pb2.Trigger.AfterAny(
   831                subtriggers=subtriggers))
   832      else:
   833        raise NotImplementedError(self)
   834  
   835    def has_ontime_pane(self):
   836      return any(t.has_ontime_pane() for t in self.triggers)
   837  
   838  
   839  class AfterAny(_ParallelTriggerFn):
   840    """Fires when any subtrigger fires.
   841  
   842    Also finishes when any subtrigger finishes.
   843    """
   844    combine_op = any
   845  
   846  
   847  class AfterAll(_ParallelTriggerFn):
   848    """Fires when all subtriggers have fired.
   849  
   850    Also finishes when all subtriggers have finished.
   851    """
   852    combine_op = all
   853  
   854  
   855  class AfterEach(TriggerFn):
   856  
   857    INDEX_TAG = _CombiningValueStateTag(
   858        'index', (lambda indices: 0 if not indices else max(indices)))
   859  
   860    def __init__(self, *triggers):
   861      self.triggers = triggers
   862  
   863    def __repr__(self):
   864      return '%s(%s)' % (
   865          self.__class__.__name__, ', '.join(str(t) for t in self.triggers))
   866  
   867    def __eq__(self, other):
   868      return type(self) == type(other) and self.triggers == other.triggers
   869  
   870    def __hash__(self):
   871      return hash(self.triggers)
   872  
   873    def on_element(self, element, window, context):
   874      ix = context.get_state(self.INDEX_TAG)
   875      if ix < len(self.triggers):
   876        self.triggers[ix].on_element(
   877            element, window, self._sub_context(context, ix))
   878  
   879    def on_merge(self, to_be_merged, merge_result, context):
   880      # This takes the furthest window on merging.
   881      # TODO(robertwb): Revisit this when merging windows logic is settled for
   882      # all possible merging situations.
   883      ix = context.get_state(self.INDEX_TAG)
   884      if ix < len(self.triggers):
   885        self.triggers[ix].on_merge(
   886            to_be_merged, merge_result, self._sub_context(context, ix))
   887  
   888    def should_fire(self, time_domain, watermark, window, context):
   889      ix = context.get_state(self.INDEX_TAG)
   890      if ix < len(self.triggers):
   891        return self.triggers[ix].should_fire(
   892            time_domain, watermark, window, self._sub_context(context, ix))
   893  
   894    def on_fire(self, watermark, window, context):
   895      ix = context.get_state(self.INDEX_TAG)
   896      if ix < len(self.triggers):
   897        if self.triggers[ix].on_fire(watermark,
   898                                     window,
   899                                     self._sub_context(context, ix)):
   900          ix += 1
   901          context.add_state(self.INDEX_TAG, ix)
   902        return ix == len(self.triggers)
   903  
   904    def reset(self, window, context):
   905      context.clear_state(self.INDEX_TAG)
   906      for ix, trigger in enumerate(self.triggers):
   907        trigger.reset(window, self._sub_context(context, ix))
   908  
   909    def may_lose_data(self, windowing):
   910      """If all sub-triggers may finish, this may finish."""
   911      may_finish = all(
   912          _IncludesMayFinish(t.may_lose_data(windowing)) for t in self.triggers)
   913      return (
   914          DataLossReason.MAY_FINISH
   915          if may_finish else DataLossReason.NO_POTENTIAL_LOSS)
   916  
   917    @staticmethod
   918    def _sub_context(context, index):
   919      return NestedContext(context, '%d/' % index)
   920  
   921    @staticmethod
   922    def from_runner_api(proto, context):
   923      return AfterEach(
   924          *[
   925              TriggerFn.from_runner_api(subtrigger, context)
   926              for subtrigger in proto.after_each.subtriggers
   927          ])
   928  
   929    def to_runner_api(self, context):
   930      return beam_runner_api_pb2.Trigger(
   931          after_each=beam_runner_api_pb2.Trigger.AfterEach(
   932              subtriggers=[
   933                  subtrigger.to_runner_api(context)
   934                  for subtrigger in self.triggers
   935              ]))
   936  
   937    def has_ontime_pane(self):
   938      return any(t.has_ontime_pane() for t in self.triggers)
   939  
   940  
   941  class OrFinally(AfterAny):
   942    @staticmethod
   943    def from_runner_api(proto, context):
   944      return OrFinally(
   945          TriggerFn.from_runner_api(proto.or_finally.main, context),
   946          # getattr is used as finally is a keyword in Python
   947          TriggerFn.from_runner_api(
   948              getattr(proto.or_finally, 'finally'), context))
   949  
   950    def to_runner_api(self, context):
   951      return beam_runner_api_pb2.Trigger(
   952          or_finally=beam_runner_api_pb2.Trigger.OrFinally(
   953              main=self.triggers[0].to_runner_api(context),
   954              # dict keyword argument is used as finally is a keyword in Python
   955              **{'finally': self.triggers[1].to_runner_api(context)}))
   956  
   957  
   958  class TriggerContext(object):
   959    def __init__(self, outer, window, clock):
   960      self._outer = outer
   961      self._window = window
   962      self._clock = clock
   963  
   964    def get_current_time(self):
   965      return self._clock.time()
   966  
   967    def set_timer(self, name, time_domain, timestamp):
   968      self._outer.set_timer(self._window, name, time_domain, timestamp)
   969  
   970    def clear_timer(self, name, time_domain):
   971      self._outer.clear_timer(self._window, name, time_domain)
   972  
   973    def add_state(self, tag, value):
   974      self._outer.add_state(self._window, tag, value)
   975  
   976    def get_state(self, tag):
   977      return self._outer.get_state(self._window, tag)
   978  
   979    def clear_state(self, tag):
   980      return self._outer.clear_state(self._window, tag)
   981  
   982  
   983  class NestedContext(object):
   984    """Namespaced context useful for defining composite triggers."""
   985    def __init__(self, outer, prefix):
   986      self._outer = outer
   987      self._prefix = prefix
   988  
   989    def get_current_time(self):
   990      return self._outer.get_current_time()
   991  
   992    def set_timer(self, name, time_domain, timestamp):
   993      self._outer.set_timer(self._prefix + name, time_domain, timestamp)
   994  
   995    def clear_timer(self, name, time_domain):
   996      self._outer.clear_timer(self._prefix + name, time_domain)
   997  
   998    def add_state(self, tag, value):
   999      self._outer.add_state(tag.with_prefix(self._prefix), value)
  1000  
  1001    def get_state(self, tag):
  1002      return self._outer.get_state(tag.with_prefix(self._prefix))
  1003  
  1004    def clear_state(self, tag):
  1005      self._outer.clear_state(tag.with_prefix(self._prefix))
  1006  
  1007  
  1008  # pylint: disable=unused-argument
  1009  class SimpleState(metaclass=ABCMeta):
  1010    """Basic state storage interface used for triggering.
  1011  
  1012    Only timers must hold the watermark (by their timestamp).
  1013    """
  1014    @abstractmethod
  1015    def set_timer(
  1016        self, window, name, time_domain, timestamp, dynamic_timer_tag=''):
  1017      pass
  1018  
  1019    @abstractmethod
  1020    def get_window(self, window_id):
  1021      pass
  1022  
  1023    @abstractmethod
  1024    def clear_timer(self, window, name, time_domain, dynamic_timer_tag=''):
  1025      pass
  1026  
  1027    @abstractmethod
  1028    def add_state(self, window, tag, value):
  1029      pass
  1030  
  1031    @abstractmethod
  1032    def get_state(self, window, tag):
  1033      pass
  1034  
  1035    @abstractmethod
  1036    def clear_state(self, window, tag):
  1037      pass
  1038  
  1039    def at(self, window, clock):
  1040      return NestedContext(TriggerContext(self, window, clock), 'trigger')
  1041  
  1042  
  1043  class UnmergedState(SimpleState):
  1044    """State suitable for use in TriggerDriver.
  1045  
  1046    This class must be implemented by each backend.
  1047    """
  1048    @abstractmethod
  1049    def set_global_state(self, tag, value):
  1050      pass
  1051  
  1052    @abstractmethod
  1053    def get_global_state(self, tag, default=None):
  1054      pass
  1055  
  1056  
  1057  # pylint: enable=unused-argument
  1058  
  1059  
  1060  class MergeableStateAdapter(SimpleState):
  1061    """Wraps an UnmergedState, tracking merged windows."""
  1062    # TODO(robertwb): A similar indirection could be used for sliding windows
  1063    # or other window_fns when a single element typically belongs to many windows.
  1064  
  1065    WINDOW_IDS = _ReadModifyWriteStateTag('window_ids')
  1066  
  1067    def __init__(self, raw_state):
  1068      self.raw_state = raw_state
  1069      self.window_ids = self.raw_state.get_global_state(self.WINDOW_IDS, {})
  1070      self.counter = None
  1071  
  1072    def set_timer(
  1073        self, window, name, time_domain, timestamp, dynamic_timer_tag=''):
  1074      self.raw_state.set_timer(
  1075          self._get_id(window),
  1076          name,
  1077          time_domain,
  1078          timestamp,
  1079          dynamic_timer_tag=dynamic_timer_tag)
  1080  
  1081    def clear_timer(self, window, name, time_domain, dynamic_timer_tag=''):
  1082      for window_id in self._get_ids(window):
  1083        self.raw_state.clear_timer(
  1084            window_id, name, time_domain, dynamic_timer_tag=dynamic_timer_tag)
  1085  
  1086    def add_state(self, window, tag, value):
  1087      if isinstance(tag, _ReadModifyWriteStateTag):
  1088        raise ValueError(
  1089            'Merging requested for non-mergeable state tag: %r.' % tag)
  1090      elif isinstance(tag, _CombiningValueStateTag):
  1091        tag = tag.without_extraction()
  1092      self.raw_state.add_state(self._get_id(window), tag, value)
  1093  
  1094    def get_state(self, window, tag):
  1095      if isinstance(tag, _CombiningValueStateTag):
  1096        original_tag, tag = tag, tag.without_extraction()
  1097      values = [
  1098          self.raw_state.get_state(window_id, tag)
  1099          for window_id in self._get_ids(window)
  1100      ]
  1101      if isinstance(tag, _ReadModifyWriteStateTag):
  1102        raise ValueError(
  1103            'Merging requested for non-mergeable state tag: %r.' % tag)
  1104      elif isinstance(tag, _CombiningValueStateTag):
  1105        return original_tag.combine_fn.extract_output(
  1106            original_tag.combine_fn.merge_accumulators(values))
  1107      elif isinstance(tag, _ListStateTag):
  1108        return [v for vs in values for v in vs]
  1109      elif isinstance(tag, _SetStateTag):
  1110        return {v for vs in values for v in vs}
  1111      elif isinstance(tag, _WatermarkHoldStateTag):
  1112        return tag.timestamp_combiner_impl.combine_all(values)
  1113      else:
  1114        raise ValueError('Invalid tag.', tag)
  1115  
  1116    def clear_state(self, window, tag):
  1117      for window_id in self._get_ids(window):
  1118        self.raw_state.clear_state(window_id, tag)
  1119      if tag is None:
  1120        del self.window_ids[window]
  1121        self._persist_window_ids()
  1122  
  1123    def merge(self, to_be_merged, merge_result):
  1124      for window in to_be_merged:
  1125        if window != merge_result:
  1126          if window in self.window_ids:
  1127            if merge_result in self.window_ids:
  1128              merge_window_ids = self.window_ids[merge_result]
  1129            else:
  1130              merge_window_ids = self.window_ids[merge_result] = []
  1131            merge_window_ids.extend(self.window_ids.pop(window))
  1132            self._persist_window_ids()
  1133  
  1134    def known_windows(self):
  1135      return list(self.window_ids)
  1136  
  1137    def get_window(self, window_id):
  1138      for window, ids in self.window_ids.items():
  1139        if window_id in ids:
  1140          return window
  1141      raise ValueError('No window for %s' % window_id)
  1142  
  1143    def _get_id(self, window):
  1144      if window in self.window_ids:
  1145        return self.window_ids[window][0]
  1146  
  1147      window_id = self._get_next_counter()
  1148      self.window_ids[window] = [window_id]
  1149      self._persist_window_ids()
  1150      return window_id
  1151  
  1152    def _get_ids(self, window):
  1153      return self.window_ids.get(window, [])
  1154  
  1155    def _get_next_counter(self):
  1156      if not self.window_ids:
  1157        self.counter = 0
  1158      elif self.counter is None:
  1159        self.counter = max(k for ids in self.window_ids.values() for k in ids)
  1160      self.counter += 1
  1161      return self.counter
  1162  
  1163    def _persist_window_ids(self):
  1164      self.raw_state.set_global_state(self.WINDOW_IDS, self.window_ids)
  1165  
  1166    def __repr__(self):
  1167      return '\n\t'.join([repr(self.window_ids)] +
  1168                         repr(self.raw_state).split('\n'))
  1169  
  1170  
  1171  def create_trigger_driver(
  1172      windowing, is_batch=False, phased_combine_fn=None, clock=None):
  1173    """Create the TriggerDriver for the given windowing and options."""
  1174  
  1175    # TODO(https://github.com/apache/beam/issues/20165): Respect closing and
  1176    # on-time behaviors. For batch, we should always fire once, no matter what.
  1177    if is_batch and windowing.triggerfn == _Never():
  1178      windowing = copy.copy(windowing)
  1179      windowing.triggerfn = Always()
  1180  
  1181    # TODO(robertwb): We can do more if we know elements are in timestamp
  1182    # sorted order.
  1183    if windowing.is_default() and is_batch:
  1184      driver = BatchGlobalTriggerDriver()
  1185    elif (windowing.windowfn == GlobalWindows() and
  1186          (windowing.triggerfn in [AfterCount(1), Always()]) and is_batch):
  1187      # Here we also just pass through all the values exactly once.
  1188      driver = BatchGlobalTriggerDriver()
  1189    else:
  1190      driver = GeneralTriggerDriver(windowing, clock)
  1191  
  1192    if phased_combine_fn:
  1193      # TODO(ccy): Refactor GeneralTriggerDriver to combine values eagerly using
  1194      # the known phased_combine_fn here.
  1195      driver = CombiningTriggerDriver(phased_combine_fn, driver)
  1196    return driver
  1197  
  1198  
  1199  class TriggerDriver(metaclass=ABCMeta):
  1200    """Breaks a series of bundle and timer firings into window (pane)s."""
  1201    @abstractmethod
  1202    def process_elements(
  1203        self,
  1204        state,
  1205        windowed_values,
  1206        output_watermark,
  1207        input_watermark=MIN_TIMESTAMP):
  1208      pass
  1209  
  1210    @abstractmethod
  1211    def process_timer(
  1212        self,
  1213        window_id,
  1214        name,
  1215        time_domain,
  1216        timestamp,
  1217        state,
  1218        input_watermark=None):
  1219      pass
  1220  
  1221    def process_entire_key(self, key, windowed_values):
  1222      # This state holds per-key, multi-window state.
  1223      state = InMemoryUnmergedState()
  1224      for wvalue in self.process_elements(state,
  1225                                          windowed_values,
  1226                                          MIN_TIMESTAMP,
  1227                                          MIN_TIMESTAMP):
  1228        yield wvalue.with_value((key, wvalue.value))
  1229      while state.timers:
  1230        fired = state.get_and_clear_timers()
  1231        for timer_window, (name, time_domain, fire_time, _) in fired:
  1232          for wvalue in self.process_timer(timer_window,
  1233                                           name,
  1234                                           time_domain,
  1235                                           fire_time,
  1236                                           state):
  1237            yield wvalue.with_value((key, wvalue.value))
  1238  
  1239  
  1240  class _UnwindowedValues(observable.ObservableMixin):
  1241    """Exposes iterable of windowed values as iterable of unwindowed values."""
  1242    def __init__(self, windowed_values):
  1243      super().__init__()
  1244      self._windowed_values = windowed_values
  1245  
  1246    def __iter__(self):
  1247      for wv in self._windowed_values:
  1248        unwindowed_value = wv.value
  1249        self.notify_observers(unwindowed_value)
  1250        yield unwindowed_value
  1251  
  1252    def __repr__(self):
  1253      return '<_UnwindowedValues of %s>' % self._windowed_values
  1254  
  1255    def __reduce__(self):
  1256      return list, (list(self), )
  1257  
  1258    def __eq__(self, other):
  1259      if isinstance(other, collections_abc.Iterable):
  1260        return all(
  1261            a == b for a, b in zip_longest(self, other, fillvalue=object()))
  1262      else:
  1263        return NotImplemented
  1264  
  1265    def __hash__(self):
  1266      return hash(tuple(self))
  1267  
  1268  
  1269  coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type(
  1270      _UnwindowedValues)
  1271  
  1272  
  1273  class BatchGlobalTriggerDriver(TriggerDriver):
  1274    """Groups all received values together.
  1275    """
  1276    GLOBAL_WINDOW_TUPLE = (GlobalWindow(), )
  1277    ONLY_FIRING = windowed_value.PaneInfo(
  1278        is_first=True,
  1279        is_last=True,
  1280        timing=windowed_value.PaneInfoTiming.ON_TIME,
  1281        index=0,
  1282        nonspeculative_index=0)
  1283  
  1284    def process_elements(
  1285        self,
  1286        state,
  1287        windowed_values,
  1288        unused_output_watermark,
  1289        unused_input_watermark=MIN_TIMESTAMP):
  1290      yield WindowedValue(
  1291          _UnwindowedValues(windowed_values),
  1292          MIN_TIMESTAMP,
  1293          self.GLOBAL_WINDOW_TUPLE,
  1294          self.ONLY_FIRING)
  1295  
  1296    def process_timer(
  1297        self,
  1298        window_id,
  1299        name,
  1300        time_domain,
  1301        timestamp,
  1302        state,
  1303        input_watermark=None):
  1304      raise TypeError('Triggers never set or called for batch default windowing.')
  1305  
  1306  
  1307  class CombiningTriggerDriver(TriggerDriver):
  1308    """Uses a phased_combine_fn to process output of wrapped TriggerDriver."""
  1309    def __init__(self, phased_combine_fn, underlying):
  1310      self.phased_combine_fn = phased_combine_fn
  1311      self.underlying = underlying
  1312  
  1313    def process_elements(
  1314        self,
  1315        state,
  1316        windowed_values,
  1317        output_watermark,
  1318        input_watermark=MIN_TIMESTAMP):
  1319      uncombined = self.underlying.process_elements(
  1320          state, windowed_values, output_watermark, input_watermark)
  1321      for output in uncombined:
  1322        yield output.with_value(self.phased_combine_fn.apply(output.value))
  1323  
  1324    def process_timer(
  1325        self,
  1326        window_id,
  1327        name,
  1328        time_domain,
  1329        timestamp,
  1330        state,
  1331        input_watermark=None):
  1332      uncombined = self.underlying.process_timer(
  1333          window_id, name, time_domain, timestamp, state, input_watermark)
  1334      for output in uncombined:
  1335        yield output.with_value(self.phased_combine_fn.apply(output.value))
  1336  
  1337  
  1338  class GeneralTriggerDriver(TriggerDriver):
  1339    """Breaks a series of bundle and timer firings into window (pane)s.
  1340  
  1341    Suitable for all variants of Windowing.
  1342    """
  1343    ELEMENTS = _ListStateTag('elements')
  1344    TOMBSTONE = _CombiningValueStateTag('tombstone', combiners.CountCombineFn())
  1345    INDEX = _CombiningValueStateTag('index', combiners.CountCombineFn())
  1346    NONSPECULATIVE_INDEX = _CombiningValueStateTag(
  1347        'nonspeculative_index', combiners.CountCombineFn())
  1348  
  1349    def __init__(self, windowing, clock):
  1350      self.clock = clock
  1351      self.allowed_lateness = windowing.allowed_lateness
  1352      self.window_fn = windowing.windowfn
  1353      self.timestamp_combiner_impl = TimestampCombiner.get_impl(
  1354          windowing.timestamp_combiner, self.window_fn)
  1355      # pylint: disable=invalid-name
  1356      self.WATERMARK_HOLD = _WatermarkHoldStateTag(
  1357          'watermark', self.timestamp_combiner_impl)
  1358      # pylint: enable=invalid-name
  1359      self.trigger_fn = windowing.triggerfn
  1360      self.accumulation_mode = windowing.accumulation_mode
  1361      self.is_merging = True
  1362  
  1363    def process_elements(
  1364        self,
  1365        state,
  1366        windowed_values,
  1367        output_watermark,
  1368        input_watermark=MIN_TIMESTAMP):
  1369      if self.is_merging:
  1370        state = MergeableStateAdapter(state)
  1371  
  1372      windows_to_elements = collections.defaultdict(list)
  1373      for wv in windowed_values:
  1374        for window in wv.windows:
  1375          # ignore expired windows
  1376          if input_watermark > window.end + self.allowed_lateness:
  1377            continue
  1378          windows_to_elements[window].append((wv.value, wv.timestamp))
  1379  
  1380      # First handle merging.
  1381      if self.is_merging:
  1382        old_windows = set(state.known_windows())
  1383        all_windows = old_windows.union(list(windows_to_elements))
  1384  
  1385        if all_windows != old_windows:
  1386          merged_away = {}
  1387  
  1388          class TriggerMergeContext(WindowFn.MergeContext):
  1389            def merge(_, to_be_merged, merge_result):  # pylint: disable=no-self-argument
  1390              for window in to_be_merged:
  1391                if window != merge_result:
  1392                  merged_away[window] = merge_result
  1393                  # Clear state associated with PaneInfo since it is
  1394                  # not preserved across merges.
  1395                  state.clear_state(window, self.INDEX)
  1396                  state.clear_state(window, self.NONSPECULATIVE_INDEX)
  1397              state.merge(to_be_merged, merge_result)
  1398              # using the outer self argument.
  1399              self.trigger_fn.on_merge(
  1400                  to_be_merged, merge_result, state.at(merge_result, self.clock))
  1401  
  1402          self.window_fn.merge(TriggerMergeContext(all_windows))
  1403  
  1404          merged_windows_to_elements = collections.defaultdict(list)
  1405          for window, values in windows_to_elements.items():
  1406            while window in merged_away:
  1407              window = merged_away[window]
  1408            merged_windows_to_elements[window].extend(values)
  1409          windows_to_elements = merged_windows_to_elements
  1410  
  1411          for window in merged_away:
  1412            state.clear_state(window, self.WATERMARK_HOLD)
  1413  
  1414      # Next handle element adding.
  1415      for window, elements in windows_to_elements.items():
  1416        if state.get_state(window, self.TOMBSTONE):
  1417          continue
  1418        # Add watermark hold.
  1419        # TODO(ccy): Add late data and garbage-collection hold support.
  1420        output_time = self.timestamp_combiner_impl.merge(
  1421            window,
  1422            (
  1423                element_output_time for element_output_time in (
  1424                    self.timestamp_combiner_impl.assign_output_time(
  1425                        window, timestamp) for unused_value,
  1426                    timestamp in elements)
  1427                if element_output_time >= output_watermark))
  1428        if output_time is not None:
  1429          state.add_state(window, self.WATERMARK_HOLD, output_time)
  1430  
  1431        context = state.at(window, self.clock)
  1432        for value, unused_timestamp in elements:
  1433          state.add_state(window, self.ELEMENTS, value)
  1434          self.trigger_fn.on_element(value, window, context)
  1435  
  1436        # Maybe fire this window.
  1437        if self.trigger_fn.should_fire(TimeDomain.WATERMARK,
  1438                                       input_watermark,
  1439                                       window,
  1440                                       context):
  1441          finished = self.trigger_fn.on_fire(input_watermark, window, context)
  1442          yield self._output(window, finished, state, output_watermark, False)
  1443  
  1444    def process_timer(
  1445        self,
  1446        window_id,
  1447        unused_name,
  1448        time_domain,
  1449        timestamp,
  1450        state,
  1451        input_watermark=None):
  1452      if input_watermark is None:
  1453        input_watermark = timestamp
  1454  
  1455      if self.is_merging:
  1456        state = MergeableStateAdapter(state)
  1457      window = state.get_window(window_id)
  1458      if state.get_state(window, self.TOMBSTONE):
  1459        return
  1460  
  1461      if time_domain in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME):
  1462        if not self.is_merging or window in state.known_windows():
  1463          context = state.at(window, self.clock)
  1464          if self.trigger_fn.should_fire(time_domain, timestamp, window, context):
  1465            finished = self.trigger_fn.on_fire(timestamp, window, context)
  1466            yield self._output(
  1467                window,
  1468                finished,
  1469                state,
  1470                timestamp,
  1471                time_domain == TimeDomain.WATERMARK)
  1472      else:
  1473        raise Exception('Unexpected time domain: %s' % time_domain)
  1474  
  1475    def _output(self, window, finished, state, output_watermark, maybe_ontime):
  1476      """Output window and clean up if appropriate."""
  1477      index = state.get_state(window, self.INDEX)
  1478      state.add_state(window, self.INDEX, 1)
  1479      if output_watermark <= window.max_timestamp():
  1480        nonspeculative_index = -1
  1481        timing = windowed_value.PaneInfoTiming.EARLY
  1482        if state.get_state(window, self.NONSPECULATIVE_INDEX):
  1483          nonspeculative_index = state.get_state(
  1484              window, self.NONSPECULATIVE_INDEX)
  1485          state.add_state(window, self.NONSPECULATIVE_INDEX, 1)
  1486          _LOGGER.warning(
  1487              'Watermark moved backwards in time '
  1488              'or late data moved window end forward.')
  1489      else:
  1490        nonspeculative_index = state.get_state(window, self.NONSPECULATIVE_INDEX)
  1491        state.add_state(window, self.NONSPECULATIVE_INDEX, 1)
  1492        timing = (
  1493            windowed_value.PaneInfoTiming.ON_TIME if maybe_ontime and
  1494            nonspeculative_index == 0 else windowed_value.PaneInfoTiming.LATE)
  1495      pane_info = windowed_value.PaneInfo(
  1496          index == 0, finished, timing, index, nonspeculative_index)
  1497  
  1498      values = state.get_state(window, self.ELEMENTS)
  1499      if finished:
  1500        # TODO(robertwb): allowed lateness
  1501        state.clear_state(window, self.ELEMENTS)
  1502        state.add_state(window, self.TOMBSTONE, 1)
  1503      elif self.accumulation_mode == AccumulationMode.DISCARDING:
  1504        state.clear_state(window, self.ELEMENTS)
  1505  
  1506      timestamp = state.get_state(window, self.WATERMARK_HOLD)
  1507      if timestamp is None:
  1508        # If no watermark hold was set, output at end of window.
  1509        timestamp = window.max_timestamp()
  1510      elif output_watermark < window.end and self.trigger_fn.has_ontime_pane():
  1511        # Hold the watermark in case there is an empty pane that needs to be fired
  1512        # at the end of the window.
  1513        pass
  1514      else:
  1515        state.clear_state(window, self.WATERMARK_HOLD)
  1516  
  1517      return WindowedValue(values, timestamp, (window, ), pane_info)
  1518  
  1519  
  1520  class InMemoryUnmergedState(UnmergedState):
  1521    """In-memory implementation of UnmergedState.
  1522  
  1523    Used for batch and testing.
  1524    """
  1525    def __init__(self, defensive_copy=False):
  1526      # TODO(robertwb): Clean defensive_copy. It is too expensive in production.
  1527      self.timers = collections.defaultdict(dict)
  1528      self.state = collections.defaultdict(lambda: collections.defaultdict(list))
  1529      self.global_state = {}
  1530      self.defensive_copy = defensive_copy
  1531  
  1532    def copy(self):
  1533      cloned_object = InMemoryUnmergedState(defensive_copy=self.defensive_copy)
  1534      cloned_object.timers = copy.deepcopy(self.timers)
  1535      cloned_object.global_state = copy.deepcopy(self.global_state)
  1536      for window in self.state:
  1537        for tag in self.state[window]:
  1538          cloned_object.state[window][tag] = copy.copy(self.state[window][tag])
  1539      return cloned_object
  1540  
  1541    def set_global_state(self, tag, value):
  1542      assert isinstance(tag, _ReadModifyWriteStateTag)
  1543      if self.defensive_copy:
  1544        value = copy.deepcopy(value)
  1545      self.global_state[tag.tag] = value
  1546  
  1547    def get_global_state(self, tag, default=None):
  1548      return self.global_state.get(tag.tag, default)
  1549  
  1550    def set_timer(
  1551        self, window, name, time_domain, timestamp, dynamic_timer_tag=''):
  1552      self.timers[window][(name, time_domain, dynamic_timer_tag)] = timestamp
  1553  
  1554    def clear_timer(self, window, name, time_domain, dynamic_timer_tag=''):
  1555      self.timers[window].pop((name, time_domain, dynamic_timer_tag), None)
  1556      if not self.timers[window]:
  1557        del self.timers[window]
  1558  
  1559    def get_window(self, window_id):
  1560      return window_id
  1561  
  1562    def add_state(self, window, tag, value):
  1563      if self.defensive_copy:
  1564        value = copy.deepcopy(value)
  1565      if isinstance(tag, _ReadModifyWriteStateTag):
  1566        self.state[window][tag.tag] = value
  1567      elif isinstance(tag, _CombiningValueStateTag):
  1568        # TODO(robertwb): Store merged accumulators.
  1569        self.state[window][tag.tag].append(value)
  1570      elif isinstance(tag, _ListStateTag):
  1571        self.state[window][tag.tag].append(value)
  1572      elif isinstance(tag, _SetStateTag):
  1573        self.state[window][tag.tag].append(value)
  1574      elif isinstance(tag, _WatermarkHoldStateTag):
  1575        self.state[window][tag.tag].append(value)
  1576      else:
  1577        raise ValueError('Invalid tag.', tag)
  1578  
  1579    def get_state(self, window, tag):
  1580      values = self.state[window][tag.tag]
  1581      if isinstance(tag, _ReadModifyWriteStateTag):
  1582        return values
  1583      elif isinstance(tag, _CombiningValueStateTag):
  1584        return tag.combine_fn.apply(values)
  1585      elif isinstance(tag, _ListStateTag):
  1586        return values
  1587      elif isinstance(tag, _SetStateTag):
  1588        return values
  1589      elif isinstance(tag, _WatermarkHoldStateTag):
  1590        return tag.timestamp_combiner_impl.combine_all(values)
  1591      else:
  1592        raise ValueError('Invalid tag.', tag)
  1593  
  1594    def clear_state(self, window, tag):
  1595      self.state[window].pop(tag.tag, None)
  1596      if not self.state[window]:
  1597        self.state.pop(window, None)
  1598  
  1599    def get_timers(
  1600        self, clear=False, watermark=MAX_TIMESTAMP, processing_time=None):
  1601      """Gets expired timers and reports if there
  1602      are any realtime timers set per state.
  1603  
  1604      Expiration is measured against the watermark for event-time timers,
  1605      and against a wall clock for processing-time timers.
  1606      """
  1607      expired = []
  1608      has_realtime_timer = False
  1609      for window, timers in list(self.timers.items()):
  1610        for (name, time_domain, dynamic_timer_tag), timestamp in list(
  1611            timers.items()):
  1612          if time_domain == TimeDomain.REAL_TIME:
  1613            time_marker = processing_time
  1614            has_realtime_timer = True
  1615          elif time_domain == TimeDomain.WATERMARK:
  1616            time_marker = watermark
  1617          else:
  1618            _LOGGER.error(
  1619                'TimeDomain error: No timers defined for time domain %s.',
  1620                time_domain)
  1621          if timestamp <= time_marker:
  1622            expired.append(
  1623                (window, (name, time_domain, timestamp, dynamic_timer_tag)))
  1624            if clear:
  1625              del timers[(name, time_domain, dynamic_timer_tag)]
  1626        if not timers and clear:
  1627          del self.timers[window]
  1628      return expired, has_realtime_timer
  1629  
  1630    def get_and_clear_timers(self, watermark=MAX_TIMESTAMP):
  1631      return self.get_timers(clear=True, watermark=watermark)[0]
  1632  
  1633    def get_earliest_hold(self):
  1634      earliest_hold = MAX_TIMESTAMP
  1635      for unused_window, tagged_states in self.state.items():
  1636        # TODO(https://github.com/apache/beam/issues/18441): currently, this
  1637        # assumes that the watermark hold tag is named "watermark".  This is
  1638        # currently only true because the only place watermark holds are set is
  1639        # in the GeneralTriggerDriver, where we use this name.  We should fix
  1640        # this by allowing enumeration of the tag types used in adding state.
  1641        if 'watermark' in tagged_states and tagged_states['watermark']:
  1642          hold = min(tagged_states['watermark']) - TIME_GRANULARITY
  1643          earliest_hold = min(earliest_hold, hold)
  1644      return earliest_hold
  1645  
  1646    def __repr__(self):
  1647      state_str = '\n'.join(
  1648          '%s: %s' % (key, dict(state)) for key, state in self.state.items())
  1649      return 'timers: %s\nstate: %s' % (dict(self.timers), state_str)