github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/userstate_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the Beam State and Timer API interfaces."""
    19  # pytype: skip-file
    20  
    21  import unittest
    22  from typing import Any
    23  from typing import List
    24  
    25  import mock
    26  import pytest
    27  
    28  import apache_beam as beam
    29  from apache_beam.coders import BytesCoder
    30  from apache_beam.coders import ListCoder
    31  from apache_beam.coders import StrUtf8Coder
    32  from apache_beam.coders import VarIntCoder
    33  from apache_beam.options.pipeline_options import PipelineOptions
    34  from apache_beam.portability import common_urns
    35  from apache_beam.portability.api import beam_runner_api_pb2
    36  from apache_beam.runners import pipeline_context
    37  from apache_beam.runners.common import DoFnSignature
    38  from apache_beam.testing.test_pipeline import TestPipeline
    39  from apache_beam.testing.test_stream import TestStream
    40  from apache_beam.testing.util import assert_that
    41  from apache_beam.testing.util import equal_to
    42  from apache_beam.transforms import trigger
    43  from apache_beam.transforms import userstate
    44  from apache_beam.transforms import window
    45  from apache_beam.transforms.combiners import ToListCombineFn
    46  from apache_beam.transforms.combiners import TopCombineFn
    47  from apache_beam.transforms.core import DoFn
    48  from apache_beam.transforms.timeutil import TimeDomain
    49  from apache_beam.transforms.userstate import BagStateSpec
    50  from apache_beam.transforms.userstate import CombiningValueStateSpec
    51  from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
    52  from apache_beam.transforms.userstate import SetStateSpec
    53  from apache_beam.transforms.userstate import TimerSpec
    54  from apache_beam.transforms.userstate import get_dofn_specs
    55  from apache_beam.transforms.userstate import is_stateful_dofn
    56  from apache_beam.transforms.userstate import on_timer
    57  from apache_beam.transforms.userstate import validate_stateful_dofn
    58  
    59  
class TestStatefulDoFn(DoFn):
  """An example stateful DoFn with state and timers.

  Fixture for the signature/validation tests: it declares two bag states,
  three watermark timers and one timer family, and requests different
  subsets of them in ``process`` and in each ``on_timer`` callback. The
  method bodies are trivial; only the declared parameters matter.
  """

  BUFFER_STATE_1 = BagStateSpec('buffer', BytesCoder())
  BUFFER_STATE_2 = BagStateSpec('buffer2', VarIntCoder())
  EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
  EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
  # EXPIRY_TIMER_3 is not requested by process(); it only appears in the
  # on_timer callbacks.
  EXPIRY_TIMER_3 = TimerSpec('expiry3', TimeDomain.WATERMARK)
  # Used as a timer family: its callback receives DoFn.DynamicTimerTagParam.
  EXPIRY_TIMER_FAMILY = TimerSpec('expiry_family', TimeDomain.WATERMARK)

  def process(
      self,
      element,
      t=DoFn.TimestampParam,
      buffer_1=DoFn.StateParam(BUFFER_STATE_1),
      buffer_2=DoFn.StateParam(BUFFER_STATE_2),
      timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
      timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
      dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
    """Pass ``element`` through unchanged.

    The timestamp, state and timer parameters are not used by the body; they
    are declared so that the DoFn's process signature requests them.
    """
    yield element

  @on_timer(EXPIRY_TIMER_1)
  def on_expiry_1(
      self,
      window=DoFn.WindowParam,
      timestamp=DoFn.TimestampParam,
      key=DoFn.KeyParam,
      buffer=DoFn.StateParam(BUFFER_STATE_1),
      timer_1=DoFn.TimerParam(EXPIRY_TIMER_1),
      timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
      timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
    """Callback for EXPIRY_TIMER_1; requests window, timestamp, key, a bag
    state and all three regular timers, and emits 'expired1'."""
    yield 'expired1'

  @on_timer(EXPIRY_TIMER_2)
  def on_expiry_2(
      self,
      buffer=DoFn.StateParam(BUFFER_STATE_2),
      timer_2=DoFn.TimerParam(EXPIRY_TIMER_2),
      timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
    """Callback for EXPIRY_TIMER_2; emits 'expired2'."""
    yield 'expired2'

  @on_timer(EXPIRY_TIMER_3)
  def on_expiry_3(
      self,
      buffer_1=DoFn.StateParam(BUFFER_STATE_1),
      buffer_2=DoFn.StateParam(BUFFER_STATE_2),
      timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)):
    """Callback for EXPIRY_TIMER_3; requests both bag states, emits
    'expired3'."""
    yield 'expired3'

  @on_timer(EXPIRY_TIMER_FAMILY)
  def on_expiry_family(
      self,
      dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY),
      dynamic_timer_tag=DoFn.DynamicTimerTagParam):
    """Callback for the timer family; emits the firing timer's dynamic tag
    paired with a fixed marker string."""
    yield (dynamic_timer_tag, 'expired_dynamic_timer')
   115  
   116  
   117  class InterfaceTest(unittest.TestCase):
   118    def _validate_dofn(self, dofn):
   119      # Construction of DoFnSignature performs validation of the given DoFn.
   120      # In particular, it ends up calling userstate._validate_stateful_dofn.
   121      # That behavior is explicitly tested below in test_validate_dofn()
   122      return DoFnSignature(dofn)
   123  
   124    @mock.patch('apache_beam.transforms.userstate.validate_stateful_dofn')
   125    def test_validate_dofn(self, unused_mock):
   126      dofn = TestStatefulDoFn()
   127      self._validate_dofn(dofn)
   128      userstate.validate_stateful_dofn.assert_called_with(dofn)
   129  
   130    def test_spec_construction(self):
   131      BagStateSpec('statename', VarIntCoder())
   132      with self.assertRaises(TypeError):
   133        BagStateSpec(123, VarIntCoder())
   134  
   135      CombiningValueStateSpec('statename', VarIntCoder(), TopCombineFn(10))
   136      with self.assertRaises(TypeError):
   137        CombiningValueStateSpec(123, VarIntCoder(), TopCombineFn(10))
   138      with self.assertRaises(TypeError):
   139        CombiningValueStateSpec('statename', VarIntCoder(), object())
   140  
   141      SetStateSpec('setstatename', VarIntCoder())
   142      with self.assertRaises(TypeError):
   143        SetStateSpec(123, VarIntCoder())
   144      with self.assertRaises(TypeError):
   145        SetStateSpec('setstatename', object())
   146  
   147      ReadModifyWriteStateSpec('valuestatename', VarIntCoder())
   148      with self.assertRaises(TypeError):
   149        ReadModifyWriteStateSpec(123, VarIntCoder())
   150      with self.assertRaises(TypeError):
   151        ReadModifyWriteStateSpec('valuestatename', object())
   152  
   153      # TODO: add more spec tests
   154      with self.assertRaises(ValueError):
   155        DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))
   156  
   157      TimerSpec('timer', TimeDomain.WATERMARK)
   158      TimerSpec('timer', TimeDomain.REAL_TIME)
   159      with self.assertRaises(ValueError):
   160        TimerSpec('timer', 'bogus_time_domain')
   161      with self.assertRaises(ValueError):
   162        DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK))
   163  
   164    def test_state_spec_proto_conversion(self):
   165      context = pipeline_context.PipelineContext()
   166      state = BagStateSpec('statename', VarIntCoder())
   167      state_proto = state.to_runner_api(context)
   168      self.assertEqual(
   169          beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn),
   170          state_proto.protocol)
   171  
   172      context = pipeline_context.PipelineContext()
   173      state = CombiningValueStateSpec(
   174          'statename', VarIntCoder(), TopCombineFn(10))
   175      state_proto = state.to_runner_api(context)
   176      self.assertEqual(
   177          beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn),
   178          state_proto.protocol)
   179  
   180      context = pipeline_context.PipelineContext()
   181      state = SetStateSpec('setstatename', VarIntCoder())
   182      state_proto = state.to_runner_api(context)
   183      self.assertEqual(
   184          beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn),
   185          state_proto.protocol)
   186  
   187      context = pipeline_context.PipelineContext()
   188      state = ReadModifyWriteStateSpec('valuestatename', VarIntCoder())
   189      state_proto = state.to_runner_api(context)
   190      self.assertEqual(
   191          beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn),
   192          state_proto.protocol)
   193  
   194    def test_param_construction(self):
   195      with self.assertRaises(ValueError):
   196        DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK))
   197      with self.assertRaises(ValueError):
   198        DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))
   199  
   200    def test_stateful_dofn_detection(self):
   201      self.assertFalse(is_stateful_dofn(DoFn()))
   202      self.assertTrue(is_stateful_dofn(TestStatefulDoFn()))
   203  
   204    def test_good_signatures(self):
   205      class BasicStatefulDoFn(DoFn):
   206        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
   207        EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)
   208        EXPIRY_TIMER_FAMILY = TimerSpec('expiry_family_1', TimeDomain.WATERMARK)
   209  
   210        def process(
   211            self,
   212            element,
   213            buffer=DoFn.StateParam(BUFFER_STATE),
   214            timer1=DoFn.TimerParam(EXPIRY_TIMER),
   215            dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
   216          yield element
   217  
   218        @on_timer(EXPIRY_TIMER)
   219        def expiry_callback(self, element, timer=DoFn.TimerParam(EXPIRY_TIMER)):
   220          yield element
   221  
   222        @on_timer(EXPIRY_TIMER_FAMILY)
   223        def expiry_family_callback(
   224            self, element, dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
   225          yield element
   226  
   227      # Validate get_dofn_specs() and timer callbacks in
   228      # DoFnSignature.
   229      stateful_dofn = BasicStatefulDoFn()
   230      signature = self._validate_dofn(stateful_dofn)
   231      expected_specs = (
   232          set([BasicStatefulDoFn.BUFFER_STATE]),
   233          set([
   234              BasicStatefulDoFn.EXPIRY_TIMER,
   235              BasicStatefulDoFn.EXPIRY_TIMER_FAMILY
   236          ]),
   237      )
   238      self.assertEqual(expected_specs, get_dofn_specs(stateful_dofn))
   239      self.assertEqual(
   240          stateful_dofn.expiry_callback,
   241          signature.timer_methods[BasicStatefulDoFn.EXPIRY_TIMER].method_value)
   242      self.assertEqual(
   243          stateful_dofn.expiry_family_callback,
   244          signature.timer_methods[
   245              BasicStatefulDoFn.EXPIRY_TIMER_FAMILY].method_value)
   246  
   247      stateful_dofn = TestStatefulDoFn()
   248      signature = self._validate_dofn(stateful_dofn)
   249      expected_specs = (
   250          set([TestStatefulDoFn.BUFFER_STATE_1, TestStatefulDoFn.BUFFER_STATE_2]),
   251          set([
   252              TestStatefulDoFn.EXPIRY_TIMER_1,
   253              TestStatefulDoFn.EXPIRY_TIMER_2,
   254              TestStatefulDoFn.EXPIRY_TIMER_3,
   255              TestStatefulDoFn.EXPIRY_TIMER_FAMILY
   256          ]))
   257      self.assertEqual(expected_specs, get_dofn_specs(stateful_dofn))
   258      self.assertEqual(
   259          stateful_dofn.on_expiry_1,
   260          signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_1].method_value)
   261      self.assertEqual(
   262          stateful_dofn.on_expiry_2,
   263          signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_2].method_value)
   264      self.assertEqual(
   265          stateful_dofn.on_expiry_3,
   266          signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_3].method_value)
   267      self.assertEqual(
   268          stateful_dofn.on_expiry_family,
   269          signature.timer_methods[
   270              TestStatefulDoFn.EXPIRY_TIMER_FAMILY].method_value)
   271  
   272    def test_bad_signatures(self):
   273      # (1) The same state parameter is duplicated on the process method.
   274      class BadStatefulDoFn1(DoFn):
   275        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
   276  
   277        def process(
   278            self,
   279            element,
   280            b1=DoFn.StateParam(BUFFER_STATE),
   281            b2=DoFn.StateParam(BUFFER_STATE)):
   282          yield element
   283  
   284      with self.assertRaises(ValueError):
   285        self._validate_dofn(BadStatefulDoFn1())
   286  
   287      # (2) The same timer parameter is duplicated on the process method.
   288      class BadStatefulDoFn2(DoFn):
   289        TIMER = TimerSpec('timer', TimeDomain.WATERMARK)
   290  
   291        def process(
   292            self, element, t1=DoFn.TimerParam(TIMER), t2=DoFn.TimerParam(TIMER)):
   293          yield element
   294  
   295      with self.assertRaises(ValueError):
   296        self._validate_dofn(BadStatefulDoFn2())
   297  
   298      # (3) The same state parameter is duplicated on the on_timer method.
   299      class BadStatefulDoFn3(DoFn):
   300        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
   301        EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
   302        EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
   303  
   304        @on_timer(EXPIRY_TIMER_1)
   305        def expiry_callback(
   306            self,
   307            element,
   308            b1=DoFn.StateParam(BUFFER_STATE),
   309            b2=DoFn.StateParam(BUFFER_STATE)):
   310          yield element
   311  
   312      with self.assertRaises(ValueError):
   313        self._validate_dofn(BadStatefulDoFn3())
   314  
   315      # (4) The same timer parameter is duplicated on the on_timer method.
   316      class BadStatefulDoFn4(DoFn):
   317        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
   318        EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
   319        EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
   320  
   321        @on_timer(EXPIRY_TIMER_1)
   322        def expiry_callback(
   323            self,
   324            element,
   325            t1=DoFn.TimerParam(EXPIRY_TIMER_2),
   326            t2=DoFn.TimerParam(EXPIRY_TIMER_2)):
   327          yield element
   328  
   329      with self.assertRaises(ValueError):
   330        self._validate_dofn(BadStatefulDoFn4())
   331  
   332      # (5) The same timer family parameter is duplicated on the process method.
   333      class BadStatefulDoFn5(DoFn):
   334        EXPIRY_TIMER_FAMILY = TimerSpec('dynamic_timer', TimeDomain.WATERMARK)
   335  
   336        def process(
   337            self,
   338            element,
   339            dynamic_timer_1=DoFn.TimerParam(EXPIRY_TIMER_FAMILY),
   340            dynamic_timer_2=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)):
   341          yield element
   342  
   343      with self.assertRaises(ValueError):
   344        self._validate_dofn(BadStatefulDoFn5())
   345  
   346    def test_validation_typos(self):
   347      # (1) Here, the user mistakenly used the same timer spec twice for two
   348      # different timer callbacks.
   349      with self.assertRaisesRegex(
   350          ValueError,
   351          r'Multiple on_timer callbacks registered for TimerSpec\(.*expiry1\).'):
   352  
   353        class StatefulDoFnWithTimerWithTypo1(DoFn):  # pylint: disable=unused-variable
   354          BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
   355          EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
   356          EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
   357  
   358          def process(self, element):
   359            pass
   360  
   361          @on_timer(EXPIRY_TIMER_1)
   362          def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
   363            yield 'expired1'
   364  
   365          # Note that we mistakenly associate this with the first timer.
   366          @on_timer(EXPIRY_TIMER_1)
   367          def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
   368            yield 'expired2'
   369  
   370      # (2) Here, the user mistakenly used the same callback name and overwrote
   371      # the first on_expiry_1 callback.
   372      class StatefulDoFnWithTimerWithTypo2(DoFn):
   373        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
   374        EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
   375        EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
   376  
   377        def process(
   378            self,
   379            element,
   380            timer1=DoFn.TimerParam(EXPIRY_TIMER_1),
   381            timer2=DoFn.TimerParam(EXPIRY_TIMER_2)):
   382          pass
   383  
   384        @on_timer(EXPIRY_TIMER_1)
   385        def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
   386          yield 'expired1'
   387  
   388        # Note that we mistakenly reuse the "on_expiry_1" name; this is valid
   389        # syntactically in Python.
   390        @on_timer(EXPIRY_TIMER_2)
   391        def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):  # pylint: disable=function-redefined
   392          yield 'expired2'
   393  
   394        # Use a stable string value for matching.
   395        def __repr__(self):
   396          return 'StatefulDoFnWithTimerWithTypo2'
   397  
   398      dofn = StatefulDoFnWithTimerWithTypo2()
   399      with self.assertRaisesRegex(
   400          ValueError,
   401          (r'The on_timer callback for TimerSpec\(.*expiry1\) is not the '
   402           r'specified .on_expiry_1 method for DoFn '
   403           r'StatefulDoFnWithTimerWithTypo2 \(perhaps it was overwritten\?\).')):
   404        validate_stateful_dofn(dofn)
   405  
   406      # (2) Here, the user forgot to add an on_timer decorator for 'expiry2'
   407      class StatefulDoFnWithTimerWithTypo3(DoFn):
   408        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
   409        EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK)
   410        EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK)
   411  
   412        def process(
   413            self,
   414            element,
   415            timer1=DoFn.TimerParam(EXPIRY_TIMER_1),
   416            timer2=DoFn.TimerParam(EXPIRY_TIMER_2)):
   417          pass
   418  
   419        @on_timer(EXPIRY_TIMER_1)
   420        def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
   421          yield 'expired1'
   422  
   423        def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)):
   424          yield 'expired2'
   425  
   426        # Use a stable string value for matching.
   427        def __repr__(self):
   428          return 'StatefulDoFnWithTimerWithTypo3'
   429  
   430      dofn = StatefulDoFnWithTimerWithTypo3()
   431      with self.assertRaisesRegex(
   432          ValueError,
   433          (r'DoFn StatefulDoFnWithTimerWithTypo3 has a TimerSpec without an '
   434           r'associated on_timer callback: TimerSpec\(.*expiry2\).')):
   435        validate_stateful_dofn(dofn)
   436  
   437  
   438  class StatefulDoFnOnDirectRunnerTest(unittest.TestCase):
   439    # pylint: disable=expression-not-assigned
   440    all_records = None  # type: List[Any]
   441  
   442    def setUp(self):
   443      # Use state on the TestCase class, since other references would be pickled
   444      # into a closure and not have the desired side effects.
   445      #
   446      # TODO(https://github.com/apache/beam/issues/18987): Use assert_that after
   447      # it works for the cases here in streaming mode.
   448      StatefulDoFnOnDirectRunnerTest.all_records = []
   449  
   450    def record_dofn(self):
   451      class RecordDoFn(DoFn):
   452        def process(self, element):
   453          StatefulDoFnOnDirectRunnerTest.all_records.append(element)
   454  
   455      return RecordDoFn()
   456  
   457    def test_simple_stateful_dofn(self):
   458      class SimpleTestStatefulDoFn(DoFn):
   459        BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
   460        EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)
   461  
   462        def process(
   463            self,
   464            element,
   465            buffer=DoFn.StateParam(BUFFER_STATE),
   466            timer1=DoFn.TimerParam(EXPIRY_TIMER)):
   467          unused_key, value = element
   468          buffer.add(b'A' + str(value).encode('latin1'))
   469          timer1.set(20)
   470  
   471        @on_timer(EXPIRY_TIMER)
   472        def expiry_callback(
   473            self,
   474            buffer=DoFn.StateParam(BUFFER_STATE),
   475            timer=DoFn.TimerParam(EXPIRY_TIMER)):
   476          yield b''.join(sorted(buffer.read()))
   477  
   478      with TestPipeline() as p:
   479        test_stream = (
   480            TestStream().advance_watermark_to(10).add_elements(
   481                [1,
   482                 2]).add_elements([3]).advance_watermark_to(25).add_elements([4]))
   483        (
   484            p
   485            | test_stream
   486            | beam.Map(lambda x: ('mykey', x))
   487            | beam.ParDo(SimpleTestStatefulDoFn())
   488            | beam.ParDo(self.record_dofn()))
   489  
   490      # Two firings should occur: once after element 3 since the timer should
   491      # fire after the watermark passes time 20, and another time after element
   492      # 4, since the timer issued at that point should fire immediately.
   493      self.assertEqual([b'A1A2A3', b'A1A2A3A4'],
   494                       StatefulDoFnOnDirectRunnerTest.all_records)
   495  
   496    def test_clearing_bag_state(self):
   497      class BagStateClearingStatefulDoFn(beam.DoFn):
   498  
   499        BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder())
   500        EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
   501        CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)
   502  
   503        def process(
   504            self,
   505            element,
   506            bag_state=beam.DoFn.StateParam(BAG_STATE),
   507            emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
   508            clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)):
   509          value = element[1]
   510          bag_state.add(value)
   511          clear_timer.set(100)
   512          emit_timer.set(1000)
   513  
   514        @on_timer(EMIT_TIMER)
   515        def emit_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)):
   516          for value in bag_state.read():
   517            yield value
   518          yield 'extra'
   519  
   520        @on_timer(CLEAR_TIMER)
   521        def clear_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)):
   522          bag_state.clear()
   523  
   524      with TestPipeline() as p:
   525        test_stream = (
   526            TestStream().advance_watermark_to(0).add_elements([
   527                ('key', 'value')
   528            ]).advance_watermark_to(100))
   529  
   530        _ = (
   531            p
   532            | test_stream
   533            | beam.ParDo(BagStateClearingStatefulDoFn())
   534            | beam.ParDo(self.record_dofn()))
   535  
   536      self.assertEqual(['extra'], StatefulDoFnOnDirectRunnerTest.all_records)
   537  
   538    def test_two_timers_one_function(self):
   539      class BagStateClearingStatefulDoFn(beam.DoFn):
   540  
   541        BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder())
   542        EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
   543        EMIT_TWICE_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)
   544  
   545        def process(
   546            self,
   547            element,
   548            bag_state=beam.DoFn.StateParam(BAG_STATE),
   549            emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
   550            emit_twice_timer=beam.DoFn.TimerParam(EMIT_TWICE_TIMER)):
   551          value = element[1]
   552          bag_state.add(value)
   553          emit_twice_timer.set(100)
   554          emit_timer.set(1000)
   555  
   556        @on_timer(EMIT_TWICE_TIMER)
   557        @on_timer(EMIT_TIMER)
   558        def emit_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)):
   559          for value in bag_state.read():
   560            yield value
   561  
   562      with TestPipeline() as p:
   563        test_stream = (
   564            TestStream().advance_watermark_to(0).add_elements([
   565                ('key', 'value')
   566            ]).advance_watermark_to(100))
   567  
   568        _ = (
   569            p
   570            | test_stream
   571            | beam.ParDo(BagStateClearingStatefulDoFn())
   572            | beam.ParDo(self.record_dofn()))
   573  
   574      self.assertEqual(['value', 'value'],
   575                       StatefulDoFnOnDirectRunnerTest.all_records)
   576  
   577    def test_simple_read_modify_write_stateful_dofn(self):
   578      class SimpleTestReadModifyWriteStatefulDoFn(DoFn):
   579        VALUE_STATE = ReadModifyWriteStateSpec('value', StrUtf8Coder())
   580  
   581        def process(self, element, last_element=DoFn.StateParam(VALUE_STATE)):
   582          last_element.write('%s:%s' % element)
   583          yield last_element.read()
   584  
   585      with TestPipeline() as p:
   586        test_stream = (
   587            TestStream().advance_watermark_to(0).add_elements([
   588                ('a', 1)
   589            ]).advance_watermark_to(10).add_elements([
   590                ('a', 3)
   591            ]).advance_watermark_to(20).add_elements([('a', 5)]))
   592        (
   593            p | test_stream
   594            | beam.ParDo(SimpleTestReadModifyWriteStatefulDoFn())
   595            | beam.ParDo(self.record_dofn()))
   596      self.assertEqual(['a:1', 'a:3', 'a:5'],
   597                       StatefulDoFnOnDirectRunnerTest.all_records)
   598  
   599    def test_clearing_read_modify_write_state(self):
   600      class SimpleClearingReadModifyWriteStatefulDoFn(DoFn):
   601        VALUE_STATE = ReadModifyWriteStateSpec('value', StrUtf8Coder())
   602  
   603        def process(self, element, last_element=DoFn.StateParam(VALUE_STATE)):
   604          value = last_element.read()
   605          if value is not None:
   606            yield value
   607          last_element.clear()
   608          last_element.write("%s:%s" % (last_element.read(), element[1]))
   609          if element[1] == 5:
   610            yield last_element.read()
   611  
   612      with TestPipeline() as p:
   613        test_stream = (
   614            TestStream().advance_watermark_to(0).add_elements([
   615                ('a', 1)
   616            ]).advance_watermark_to(10).add_elements([
   617                ('a', 3)
   618            ]).advance_watermark_to(20).add_elements([('a', 5)]))
   619        (
   620            p | test_stream
   621            | beam.ParDo(SimpleClearingReadModifyWriteStatefulDoFn())
   622            | beam.ParDo(self.record_dofn()))
   623      self.assertEqual(['None:1', 'None:3', 'None:5'],
   624                       StatefulDoFnOnDirectRunnerTest.all_records)
   625  
   626    def test_simple_set_stateful_dofn(self):
   627      class SimpleTestSetStatefulDoFn(DoFn):
   628        BUFFER_STATE = SetStateSpec('buffer', VarIntCoder())
   629        EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)
   630  
   631        def process(
   632            self,
   633            element,
   634            buffer=DoFn.StateParam(BUFFER_STATE),
   635            timer1=DoFn.TimerParam(EXPIRY_TIMER)):
   636          unused_key, value = element
   637          buffer.add(value)
   638          timer1.set(20)
   639  
   640        @on_timer(EXPIRY_TIMER)
   641        def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE)):
   642          yield sorted(buffer.read())
   643  
   644      with TestPipeline() as p:
   645        test_stream = (
   646            TestStream().advance_watermark_to(10).add_elements(
   647                [1, 2, 3]).add_elements([2]).advance_watermark_to(24))
   648        (
   649            p
   650            | test_stream
   651            | beam.Map(lambda x: ('mykey', x))
   652            | beam.ParDo(SimpleTestSetStatefulDoFn())
   653            | beam.ParDo(self.record_dofn()))
   654  
   655      # Two firings should occur: once after element 3 since the timer should
   656      # fire after the watermark passes time 20, and another time after element
   657      # 4, since the timer issued at that point should fire immediately.
   658      self.assertEqual([[1, 2, 3]], StatefulDoFnOnDirectRunnerTest.all_records)
   659  
   660    def test_clearing_set_state(self):
   661      class SetStateClearingStatefulDoFn(beam.DoFn):
   662  
   663        SET_STATE = SetStateSpec('buffer', StrUtf8Coder())
   664        EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
   665        CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)
   666  
   667        def process(
   668            self,
   669            element,
   670            set_state=beam.DoFn.StateParam(SET_STATE),
   671            emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
   672            clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)):
   673          value = element[1]
   674          set_state.add(value)
   675          clear_timer.set(100)
   676          emit_timer.set(1000)
   677  
   678        @on_timer(EMIT_TIMER)
   679        def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
   680          for value in set_state.read():
   681            yield value
   682  
   683        @on_timer(CLEAR_TIMER)
   684        def clear_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
   685          set_state.clear()
   686          set_state.add('different-value')
   687  
   688      with TestPipeline() as p:
   689        test_stream = (
   690            TestStream().advance_watermark_to(0).add_elements([
   691                ('key1', 'value1')
   692            ]).advance_watermark_to(100))
   693  
   694        _ = (
   695            p
   696            | test_stream
   697            | beam.ParDo(SetStateClearingStatefulDoFn())
   698            | beam.ParDo(self.record_dofn()))
   699  
   700      self.assertEqual(['different-value'],
   701                       StatefulDoFnOnDirectRunnerTest.all_records)
   702  
   703    def test_stateful_set_state_portably(self):
   704      class SetStatefulDoFn(beam.DoFn):
   705  
   706        SET_STATE = SetStateSpec('buffer', VarIntCoder())
   707  
   708        def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)):
   709          _, value = element
   710          aggregated_value = 0
   711          set_state.add(value)
   712          for saved_value in set_state.read():
   713            aggregated_value += saved_value
   714          yield aggregated_value
   715  
   716      with TestPipeline() as p:
   717        values = p | beam.Create([('key', 1), ('key', 2), ('key', 3), ('key', 4),
   718                                  ('key', 3)],
   719                                 reshuffle=False)
   720        actual_values = (values | beam.ParDo(SetStatefulDoFn()))
   721        assert_that(actual_values, equal_to([1, 3, 6, 10, 10]))
   722  
   723    def test_stateful_set_state_clean_portably(self):
   724      class SetStateClearingStatefulDoFn(beam.DoFn):
   725  
   726        SET_STATE = SetStateSpec('buffer', VarIntCoder())
   727        EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
   728  
   729        def process(
   730            self,
   731            element,
   732            set_state=beam.DoFn.StateParam(SET_STATE),
   733            emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
   734          _, value = element
   735          set_state.add(value)
   736  
   737          all_elements = [element for element in set_state.read()]
   738  
   739          if len(all_elements) == 5:
   740            set_state.clear()
   741            set_state.add(100)
   742            emit_timer.set(1)
   743  
   744        @on_timer(EMIT_TIMER)
   745        def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
   746          yield sorted(set_state.read())
   747  
   748      with TestPipeline() as p:
   749        values = p | beam.Create([('key', 1), ('key', 2), ('key', 3), ('key', 4),
   750                                  ('key', 5)])
   751        actual_values = (
   752            values
   753            | beam.Map(lambda t: window.TimestampedValue(t, 1))
   754            | beam.WindowInto(window.FixedWindows(1))
   755            | beam.ParDo(SetStateClearingStatefulDoFn()))
   756  
   757        assert_that(actual_values, equal_to([[100]]))
   758  
   759    def test_stateful_dofn_nonkeyed_input(self):
   760      p = TestPipeline()
   761      values = p | beam.Create([1, 2, 3])
   762      with self.assertRaisesRegex(
   763          ValueError,
   764          ('Input elements to the transform .* with stateful DoFn must be '
   765           'key-value pairs.')):
   766        values | beam.ParDo(TestStatefulDoFn())
   767  
   768    def test_generate_sequence_with_realtime_timer(self):
   769      from apache_beam.transforms.combiners import CountCombineFn
   770  
   771      class GenerateRecords(beam.DoFn):
   772  
   773        EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.REAL_TIME)
   774        COUNT_STATE = CombiningValueStateSpec(
   775            'count_state', VarIntCoder(), CountCombineFn())
   776  
   777        def __init__(self, frequency, total_records):
   778          self.total_records = total_records
   779          self.frequency = frequency
   780  
   781        def process(self, element, emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
   782          # Processing time timers should be set on ABSOLUTE TIME.
   783          emit_timer.set(self.frequency)
   784          yield element[1]
   785  
   786        @on_timer(EMIT_TIMER)
   787        def emit_values(
   788            self,
   789            emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
   790            count_state=beam.DoFn.StateParam(COUNT_STATE)):
   791          count = count_state.read() or 0
   792          if self.total_records == count:
   793            return
   794  
   795          count_state.add(1)
   796          # Processing time timers should be set on ABSOLUTE TIME.
   797          emit_timer.set(count + 1 + self.frequency)
   798          yield 'value'
   799  
   800      TOTAL_RECORDS = 3
   801      FREQUENCY = 1
   802  
   803      test_stream = (
   804          TestStream().advance_watermark_to(0).add_elements([
   805              ('key', 0)
   806          ]).advance_processing_time(1)  # Timestamp: 1
   807          .add_elements([('key', 1)]).advance_processing_time(1)  # Timestamp: 2
   808          .add_elements([('key', 2)]).advance_processing_time(1)  # Timestamp: 3
   809          .add_elements([('key', 3)]))
   810  
   811      with beam.Pipeline(argv=['--streaming', '--runner=DirectRunner']) as p:
   812        _ = (
   813            p
   814            | test_stream
   815            | beam.ParDo(GenerateRecords(FREQUENCY, TOTAL_RECORDS))
   816            | beam.ParDo(self.record_dofn()))
   817  
   818      self.assertEqual(
   819          # 4 RECORDS go through process
   820          # 3 values are emitted from timer
   821          # Timestamp moves gradually.
   822          [0, 'value', 1, 'value', 2, 'value', 3],
   823          StatefulDoFnOnDirectRunnerTest.all_records)
   824  
   825    def test_simple_stateful_dofn_combining(self):
   826      class SimpleTestStatefulDoFn(DoFn):
   827        BUFFER_STATE = CombiningValueStateSpec(
   828            'buffer', ListCoder(VarIntCoder()), ToListCombineFn())
   829        EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK)
   830  
   831        def process(
   832            self,
   833            element,
   834            buffer=DoFn.StateParam(BUFFER_STATE),
   835            timer1=DoFn.TimerParam(EXPIRY_TIMER)):
   836          unused_key, value = element
   837          buffer.add(value)
   838          timer1.set(20)
   839  
   840        @on_timer(EXPIRY_TIMER)
   841        def expiry_callback(
   842            self,
   843            buffer=DoFn.StateParam(BUFFER_STATE),
   844            timer=DoFn.TimerParam(EXPIRY_TIMER)):
   845          yield ''.join(str(x) for x in sorted(buffer.read()))
   846  
   847      with TestPipeline() as p:
   848        test_stream = (
   849            TestStream().advance_watermark_to(10).add_elements(
   850                [1,
   851                 2]).add_elements([3]).advance_watermark_to(25).add_elements([4]))
   852        (
   853            p
   854            | test_stream
   855            | beam.Map(lambda x: ('mykey', x))
   856            | beam.ParDo(SimpleTestStatefulDoFn())
   857            | beam.ParDo(self.record_dofn()))
   858  
   859      self.assertEqual(['123', '1234'],
   860                       StatefulDoFnOnDirectRunnerTest.all_records)
   861  
   862    def test_timer_output_timestamp(self):
   863      class TimerEmittingStatefulDoFn(DoFn):
   864        EMIT_TIMER_1 = TimerSpec('emit1', TimeDomain.WATERMARK)
   865        EMIT_TIMER_2 = TimerSpec('emit2', TimeDomain.WATERMARK)
   866        EMIT_TIMER_3 = TimerSpec('emit3', TimeDomain.WATERMARK)
   867  
   868        def process(
   869            self,
   870            element,
   871            timer1=DoFn.TimerParam(EMIT_TIMER_1),
   872            timer2=DoFn.TimerParam(EMIT_TIMER_2),
   873            timer3=DoFn.TimerParam(EMIT_TIMER_3)):
   874          timer1.set(10)
   875          timer2.set(20)
   876          timer3.set(30)
   877  
   878        @on_timer(EMIT_TIMER_1)
   879        def emit_callback_1(self):
   880          yield 'timer1'
   881  
   882        @on_timer(EMIT_TIMER_2)
   883        def emit_callback_2(self):
   884          yield 'timer2'
   885  
   886        @on_timer(EMIT_TIMER_3)
   887        def emit_callback_3(self):
   888          yield 'timer3'
   889  
   890      class TimestampReifyingDoFn(DoFn):
   891        def process(self, element, ts=DoFn.TimestampParam):
   892          yield (element, int(ts))
   893  
   894      with TestPipeline() as p:
   895        test_stream = (TestStream().advance_watermark_to(10).add_elements([1]))
   896        (
   897            p
   898            | test_stream
   899            | beam.Map(lambda x: ('mykey', x))
   900            | beam.ParDo(TimerEmittingStatefulDoFn())
   901            | beam.ParDo(TimestampReifyingDoFn())
   902            | beam.ParDo(self.record_dofn()))
   903  
   904      self.assertEqual([('timer1', 10), ('timer2', 20), ('timer3', 30)],
   905                       sorted(StatefulDoFnOnDirectRunnerTest.all_records))
   906  
   907    def test_timer_output_timestamp_and_window(self):
   908      class TimerEmittingStatefulDoFn(DoFn):
   909        EMIT_TIMER_1 = TimerSpec('emit1', TimeDomain.WATERMARK)
   910  
   911        def process(self, element, timer1=DoFn.TimerParam(EMIT_TIMER_1)):
   912          timer1.set(10)
   913  
   914        @on_timer(EMIT_TIMER_1)
   915        def emit_callback_1(
   916            self,
   917            window=DoFn.WindowParam,
   918            ts=DoFn.TimestampParam,
   919            key=DoFn.KeyParam):
   920          yield (
   921              'timer1-{key}'.format(key=key),
   922              int(ts),
   923              int(window.start),
   924              int(window.end))
   925  
   926      pipeline_options = PipelineOptions()
   927      with TestPipeline(options=pipeline_options) as p:
   928        test_stream = (TestStream().advance_watermark_to(10).add_elements([1]))
   929        (
   930            p
   931            | test_stream
   932            | beam.Map(lambda x: ('mykey', x))
   933            | "window_into" >> beam.WindowInto(
   934                window.FixedWindows(5),
   935                accumulation_mode=trigger.AccumulationMode.DISCARDING)
   936            | beam.ParDo(TimerEmittingStatefulDoFn())
   937            | beam.ParDo(self.record_dofn()))
   938  
   939      self.assertEqual([('timer1-mykey', 10, 10, 15)],
   940                       sorted(StatefulDoFnOnDirectRunnerTest.all_records))
   941  
   942    def test_timer_default_tag(self):
   943      class DynamicTimerDoFn(DoFn):
   944        EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK)
   945  
   946        def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)):
   947          emit.set(10)
   948          emit.set(20, dynamic_timer_tag='')
   949  
   950        @on_timer(EMIT_TIMER_FAMILY)
   951        def emit_callback(
   952            self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
   953          yield (tag, ts)
   954  
   955      with TestPipeline() as p:
   956        test_stream = (TestStream().advance_watermark_to(10).add_elements(
   957            [1])).advance_watermark_to_infinity()
   958        (
   959            p
   960            | test_stream
   961            | beam.Map(lambda x: ('mykey', x))
   962            | beam.ParDo(DynamicTimerDoFn())
   963            | beam.ParDo(self.record_dofn()))
   964  
   965      self.assertEqual([('', 20)],
   966                       sorted(StatefulDoFnOnDirectRunnerTest.all_records))
   967  
   968    def test_dynamic_timer_simple_dofn(self):
   969      class DynamicTimerDoFn(DoFn):
   970        EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK)
   971  
   972        def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)):
   973          emit.set(10, dynamic_timer_tag='emit1')
   974          emit.set(20, dynamic_timer_tag='emit2')
   975          emit.set(30, dynamic_timer_tag='emit3')
   976  
   977        @on_timer(EMIT_TIMER_FAMILY)
   978        def emit_callback(
   979            self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
   980          yield (tag, ts)
   981  
   982      with TestPipeline() as p:
   983        test_stream = (TestStream().advance_watermark_to(10).add_elements(
   984            [1])).advance_watermark_to_infinity()
   985        (
   986            p
   987            | test_stream
   988            | beam.Map(lambda x: ('mykey', x))
   989            | beam.ParDo(DynamicTimerDoFn())
   990            | beam.ParDo(self.record_dofn()))
   991  
   992      self.assertEqual([('emit1', 10), ('emit2', 20), ('emit3', 30)],
   993                       sorted(StatefulDoFnOnDirectRunnerTest.all_records))
   994  
   995    @pytest.mark.no_xdist
   996    @pytest.mark.timeout(10)
   997    def test_dynamic_timer_clear_then_set_timer(self):
   998      class EmitTwoEvents(DoFn):
   999        EMIT_CLEAR_SET_TIMER = TimerSpec('emitclear', TimeDomain.WATERMARK)
  1000  
  1001        def process(self, element, emit=DoFn.TimerParam(EMIT_CLEAR_SET_TIMER)):
  1002          yield ('1', 'set')
  1003          emit.set(1)
  1004  
  1005        @on_timer(EMIT_CLEAR_SET_TIMER)
  1006        def emit_clear(self):
  1007          yield ('1', 'clear')
  1008  
  1009      class DynamicTimerDoFn(DoFn):
  1010        EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK)
  1011  
  1012        def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)):
  1013          if element[1] == 'set':
  1014            emit.set(10, dynamic_timer_tag='emit1')
  1015            emit.set(20, dynamic_timer_tag='emit2')
  1016          if element[1] == 'clear':
  1017            emit.set(30, dynamic_timer_tag='emit3')
  1018            emit.clear(dynamic_timer_tag='emit3')
  1019            emit.set(40, dynamic_timer_tag='emit3')
  1020          return []
  1021  
  1022        @on_timer(EMIT_TIMER_FAMILY)
  1023        def emit_callback(
  1024            self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
  1025          yield (tag, ts)
  1026  
  1027      with TestPipeline() as p:
  1028        res = (
  1029            p
  1030            | beam.Create([('1', 'impulse')])
  1031            | beam.ParDo(EmitTwoEvents())
  1032            | beam.ParDo(DynamicTimerDoFn()))
  1033        assert_that(res, equal_to([('emit1', 10), ('emit2', 20), ('emit3', 40)]))
  1034  
  def test_dynamic_timer_clear_timer(self):
    """Clearing one dynamic tag cancels it without touching the others."""
    class DynamicTimerDoFn(DoFn):
      EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK)

      def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)):
        command = element[1]
        if command == 'set':
          for tag, fire_at in (('emit1', 10), ('emit2', 20), ('emit3', 30)):
            emit.set(fire_at, dynamic_timer_tag=tag)
        elif command == 'clear':
          emit.clear(dynamic_timer_tag='emit3')

      @on_timer(EMIT_TIMER_FAMILY)
      def emit_callback(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield tag, ts

    with TestPipeline() as p:
      stream = (
          TestStream()
          .advance_watermark_to(5)
          .add_elements(['set'])
          .advance_watermark_to(10)
          .add_elements(['clear'])
          .advance_watermark_to_infinity())
      _ = (
          p
          | stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(DynamicTimerDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('emit1', 10), ('emit2', 20)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))
  1066  
  def test_dynamic_timer_multiple(self):
    """Two timer families with their own dynamic tags fire independently."""
    class DynamicTimerDoFn(DoFn):
      EMIT_TIMER_FAMILY1 = TimerSpec('emit_family_1', TimeDomain.WATERMARK)
      EMIT_TIMER_FAMILY2 = TimerSpec('emit_family_2', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          emit1=DoFn.TimerParam(EMIT_TIMER_FAMILY1),
          emit2=DoFn.TimerParam(EMIT_TIMER_FAMILY2)):
        # Family 1 is armed in ascending time order, family 2 in descending.
        for tag, fire_at in (('emit11', 10), ('emit12', 20), ('emit13', 30)):
          emit1.set(fire_at, dynamic_timer_tag=tag)
        for tag, fire_at in (('emit21', 30), ('emit22', 20), ('emit23', 10)):
          emit2.set(fire_at, dynamic_timer_tag=tag)

      @on_timer(EMIT_TIMER_FAMILY1)
      def emit_callback(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield tag, ts

      @on_timer(EMIT_TIMER_FAMILY2)
      def emit_callback_2(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield tag, ts

    with TestPipeline() as p:
      stream = (
          TestStream()
          .advance_watermark_to(5)
          .add_elements(['1'])
          .advance_watermark_to_infinity())
      _ = (
          p
          | stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(DynamicTimerDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('emit11', 10), ('emit12', 20), ('emit13', 30),
                      ('emit21', 30), ('emit22', 20), ('emit23', 10)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))
  1108  
  def test_dynamic_timer_and_simple_timer(self):
    """A dynamic timer family and a plain timer coexist in one DoFn."""
    class DynamicTimerDoFn(DoFn):
      EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK)
      GC_TIMER = TimerSpec('gc', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          emit=DoFn.TimerParam(EMIT_TIMER_FAMILY),
          gc=DoFn.TimerParam(GC_TIMER)):
        for tag, fire_at in (('emit1', 10), ('emit2', 20), ('emit3', 30)):
          emit.set(fire_at, dynamic_timer_tag=tag)
        gc.set(40)

      @on_timer(EMIT_TIMER_FAMILY)
      def emit_callback(
          self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam):
        yield tag, ts

      @on_timer(GC_TIMER)
      def gc(self, ts=DoFn.TimestampParam):
        yield 'gc', ts

    with TestPipeline() as p:
      stream = (
          TestStream()
          .advance_watermark_to(5)
          .add_elements(['1'])
          .advance_watermark_to_infinity())
      _ = (
          p
          | stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(DynamicTimerDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('emit1', 10), ('emit2', 20), ('emit3', 30), ('gc', 40)],
                     sorted(StatefulDoFnOnDirectRunnerTest.all_records))
  1146  
  def test_index_assignment(self):
    """A summing CombiningValueState hands out consecutive per-key indices."""
    class IndexAssigningStatefulDoFn(DoFn):
      INDEX_STATE = CombiningValueStateSpec('index', sum)

      def process(self, element, state=DoFn.StateParam(INDEX_STATE)):
        _, value = element
        # Emit the current index before incrementing; the first is 0.
        yield value, state.read()
        state.add(1)

    with TestPipeline() as p:
      stream = (
          TestStream()
          .advance_watermark_to(10)
          .add_elements(['A', 'B'])
          .add_elements(['C'])
          .advance_watermark_to(25)
          .add_elements(['D']))
      _ = (
          p
          | stream
          | beam.Map(lambda x: ('mykey', x))
          | beam.ParDo(IndexAssigningStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    self.assertEqual([('A', 0), ('B', 1), ('C', 2), ('D', 3)],
                     StatefulDoFnOnDirectRunnerTest.all_records)
  1171  
  def test_hash_join(self):
    """Joins consecutive values per key; a timer flushes unmatched ones.

    The first value seen for a key is buffered and a watermark timer is set
    for 100. The next value for the same key is joined with the buffered one,
    after which both the buffer and the timer are cleared. If the timer fires
    while a value is still buffered, that value is emitted as unmatched.
    """
    class HashJoinStatefulDoFn(DoFn):
      BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
      UNMATCHED_TIMER = TimerSpec('unmatched', TimeDomain.WATERMARK)

      def process(
          self,
          element,
          state=DoFn.StateParam(BUFFER_STATE),
          timer=DoFn.TimerParam(UNMATCHED_TIMER)):
        key, value = element
        existing_values = list(state.read())
        if not existing_values:
          # No partner yet: buffer this value and arm the expiry timer.
          state.add(value)
          timer.set(100)
        else:
          # Partner found: emit the joined record and reset per-key state.
          yield b'Record<%s,%s,%s>' % (key, existing_values[0], value)
          state.clear()
          timer.clear()

      @on_timer(UNMATCHED_TIMER)
      def expiry_callback(self, state=DoFn.StateParam(BUFFER_STATE)):
        buffered = list(state.read())
        # process() only adds to the bag when it is empty, so at most one
        # value can be buffered per key.
        assert len(buffered) == 1, buffered
        state.clear()
        yield b'Unmatched<%s>' % (buffered[0], )

    with TestPipeline() as p:
      test_stream = (
          TestStream().advance_watermark_to(10).add_elements([
              (b'A', b'a'), (b'B', b'b')
          ]).add_elements([
              (b'A', b'aa'), (b'C', b'c')
          ]).advance_watermark_to(25).add_elements([
              (b'A', b'aaa'), (b'B', b'bb')
          ]).add_elements([
              (b'D', b'd'), (b'D', b'dd'), (b'D', b'ddd'), (b'D', b'dddd')
          ]).advance_watermark_to(125).add_elements([(b'C', b'cc')]))
      (
          p
          | test_stream
          | beam.ParDo(HashJoinStatefulDoFn())
          | beam.ParDo(self.record_dofn()))

    # The order of records across different keys is not asserted, so compare
    # against the sorted output (the expected list is written in sorted
    # order). Expected comes first so a failure labels both sides correctly;
    # the previous equal_to(actual)(expected) call had them swapped.
    self.assertEqual([
        b'Record<A,a,aa>',
        b'Record<B,b,bb>',
        b'Record<D,d,dd>',
        b'Record<D,ddd,dddd>',
        b'Unmatched<aaa>',
        b'Unmatched<c>',
        b'Unmatched<cc>'
    ], sorted(StatefulDoFnOnDirectRunnerTest.all_records))
  1225  
  1226  
  1227  if __name__ == '__main__':
  1228    unittest.main()