github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/test_stream_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the test_stream module."""
    19  
    20  # pytype: skip-file
    21  
    22  import unittest
    23  
    24  import apache_beam as beam
    25  from apache_beam.options.pipeline_options import PipelineOptions
    26  from apache_beam.options.pipeline_options import StandardOptions
    27  from apache_beam.options.pipeline_options import TypeOptions
    28  from apache_beam.portability import common_urns
    29  from apache_beam.portability.api import beam_interactive_api_pb2
    30  from apache_beam.portability.api import beam_runner_api_pb2
    31  from apache_beam.testing.test_pipeline import TestPipeline
    32  from apache_beam.testing.test_stream import ElementEvent
    33  from apache_beam.testing.test_stream import OutputFormat
    34  from apache_beam.testing.test_stream import ProcessingTimeEvent
    35  from apache_beam.testing.test_stream import ReverseTestStream
    36  from apache_beam.testing.test_stream import TestStream
    37  from apache_beam.testing.test_stream import WatermarkEvent
    38  from apache_beam.testing.test_stream import WindowedValueHolder
    39  from apache_beam.testing.test_stream_service import TestStreamServiceController
    40  from apache_beam.testing.util import assert_that
    41  from apache_beam.testing.util import equal_to
    42  from apache_beam.testing.util import equal_to_per_window
    43  from apache_beam.transforms import trigger
    44  from apache_beam.transforms import window
    45  from apache_beam.transforms.window import FixedWindows
    46  from apache_beam.transforms.window import TimestampedValue
    47  from apache_beam.utils import timestamp
    48  from apache_beam.utils.timestamp import Timestamp
    49  from apache_beam.utils.windowed_value import PaneInfo
    50  from apache_beam.utils.windowed_value import PaneInfoTiming
    51  from apache_beam.utils.windowed_value import WindowedValue
    52  
    53  
    54  class TestStreamTest(unittest.TestCase):
    55    def test_basic_test_stream(self):
    56      test_stream = (TestStream()
    57                     .advance_watermark_to(0)
    58                     .add_elements([
    59                         'a',
    60                         WindowedValue('b', 3, []),
    61                         TimestampedValue('c', 6)])
    62                     .advance_processing_time(10)
    63                     .advance_watermark_to(8)
    64                     .add_elements(['d'])
    65                     .advance_watermark_to_infinity())  # yapf: disable
    66      self.assertEqual(
    67          test_stream._events,
    68          [
    69              WatermarkEvent(0),
    70              ElementEvent([
    71                  TimestampedValue('a', 0),
    72                  TimestampedValue('b', 3),
    73                  TimestampedValue('c', 6),
    74              ]),
    75              ProcessingTimeEvent(10),
    76              WatermarkEvent(8),
    77              ElementEvent([
    78                  TimestampedValue('d', 8),
    79              ]),
    80              WatermarkEvent(timestamp.MAX_TIMESTAMP),
    81          ])
    82  
    83    def test_test_stream_errors(self):
    84      with self.assertRaises(
    85          AssertionError, msg=('Watermark must strictly-monotonically advance.')):
    86        _ = (TestStream().advance_watermark_to(5).advance_watermark_to(4))
    87  
    88      with self.assertRaises(
    89          AssertionError,
    90          msg=('Must advance processing time by positive amount.')):
    91        _ = (TestStream().advance_processing_time(-1))
    92  
    93      with self.assertRaises(
    94          AssertionError,
    95          msg=('Element timestamp must be before timestamp.MAX_TIMESTAMP.')):
    96        _ = (
    97            TestStream().add_elements(
    98                [TimestampedValue('a', timestamp.MAX_TIMESTAMP)]))
    99  
   100    def test_basic_execution(self):
   101      test_stream = (TestStream()
   102                     .advance_watermark_to(10)
   103                     .add_elements(['a', 'b', 'c'])
   104                     .advance_watermark_to(20)
   105                     .add_elements(['d'])
   106                     .add_elements(['e'])
   107                     .advance_processing_time(10)
   108                     .advance_watermark_to(300)
   109                     .add_elements([TimestampedValue('late', 12)])
   110                     .add_elements([TimestampedValue('last', 310)])
   111                     .advance_watermark_to_infinity())  # yapf: disable
   112  
   113      class RecordFn(beam.DoFn):
   114        def process(
   115            self,
   116            element=beam.DoFn.ElementParam,
   117            timestamp=beam.DoFn.TimestampParam):
   118          yield (element, timestamp)
   119  
   120      options = PipelineOptions()
   121      options.view_as(StandardOptions).streaming = True
   122      with TestPipeline(options=options) as p:
   123        my_record_fn = RecordFn()
   124        records = p | test_stream | beam.ParDo(my_record_fn)
   125  
   126        assert_that(
   127            records,
   128            equal_to([
   129                ('a', timestamp.Timestamp(10)),
   130                ('b', timestamp.Timestamp(10)),
   131                ('c', timestamp.Timestamp(10)),
   132                ('d', timestamp.Timestamp(20)),
   133                ('e', timestamp.Timestamp(20)),
   134                ('late', timestamp.Timestamp(12)),
   135                ('last', timestamp.Timestamp(310)),
   136            ]))
   137  
   138    def test_multiple_outputs(self):
   139      """Tests that the TestStream supports emitting to multiple PCollections."""
   140      letters_elements = [
   141          TimestampedValue('a', 6),
   142          TimestampedValue('b', 7),
   143          TimestampedValue('c', 8),
   144      ]
   145      numbers_elements = [
   146          TimestampedValue('1', 11),
   147          TimestampedValue('2', 12),
   148          TimestampedValue('3', 13),
   149      ]
   150      test_stream = (TestStream()
   151          .advance_watermark_to(5, tag='letters')
   152          .add_elements(letters_elements, tag='letters')
   153          .advance_watermark_to(10, tag='numbers')
   154          .add_elements(numbers_elements, tag='numbers'))  # yapf: disable
   155  
   156      class RecordFn(beam.DoFn):
   157        def process(
   158            self,
   159            element=beam.DoFn.ElementParam,
   160            timestamp=beam.DoFn.TimestampParam):
   161          yield (element, timestamp)
   162  
   163      options = StandardOptions(streaming=True)
   164      p = TestPipeline(options=options)
   165  
   166      main = p | test_stream
   167      letters = main['letters'] | 'record letters' >> beam.ParDo(RecordFn())
   168      numbers = main['numbers'] | 'record numbers' >> beam.ParDo(RecordFn())
   169  
   170      assert_that(
   171          letters,
   172          equal_to([('a', Timestamp(6)), ('b', Timestamp(7)),
   173                    ('c', Timestamp(8))]),
   174          label='assert letters')
   175  
   176      assert_that(
   177          numbers,
   178          equal_to([('1', Timestamp(11)), ('2', Timestamp(12)),
   179                    ('3', Timestamp(13))]),
   180          label='assert numbers')
   181  
   182      p.run()
   183  
   184    def test_multiple_outputs_with_watermark_advancement(self):
   185      """Tests that the TestStream can independently control output watermarks."""
   186  
   187      # Purposely set the watermark of numbers to 20 then letters to 5 to test
   188      # that the watermark advancement is per PCollection.
   189      #
   190      # This creates two PCollections, (a, b, c) and (1, 2, 3). These will be
   191      # emitted at different times so that they will have different windows. The
   192      # watermark advancement is checked by checking their windows. If the
   193      # watermark does not advance, then the windows will be [-inf, -inf). If the
   194      # windows do not advance separately, then the PCollections will both
   195      # windowed in [15, 30).
   196      letters_elements = [
   197          TimestampedValue('a', 6),
   198          TimestampedValue('b', 7),
   199          TimestampedValue('c', 8),
   200      ]
   201      numbers_elements = [
   202          TimestampedValue('1', 21),
   203          TimestampedValue('2', 22),
   204          TimestampedValue('3', 23),
   205      ]
   206      test_stream = (TestStream()
   207                     .advance_watermark_to(0, tag='letters')
   208                     .advance_watermark_to(0, tag='numbers')
   209                     .advance_watermark_to(20, tag='numbers')
   210                     .advance_watermark_to(5, tag='letters')
   211                     .add_elements(letters_elements, tag='letters')
   212                     .advance_watermark_to(10, tag='letters')
   213                     .add_elements(numbers_elements, tag='numbers')
   214                     .advance_watermark_to(30, tag='numbers')) # yapf: disable
   215  
   216      options = StandardOptions(streaming=True)
   217      p = TestPipeline(options=options)
   218  
   219      main = p | test_stream
   220  
   221      # Use an AfterWatermark trigger with an early firing to test that the
   222      # watermark is advancing properly and that the element is being emitted in
   223      # the correct window.
   224      letters = (
   225          main['letters']
   226          | 'letter windows' >> beam.WindowInto(
   227              FixedWindows(15),
   228              trigger=trigger.AfterWatermark(early=trigger.AfterCount(1)),
   229              accumulation_mode=trigger.AccumulationMode.DISCARDING)
   230          | 'letter with key' >> beam.Map(lambda x: ('k', x))
   231          | 'letter gbk' >> beam.GroupByKey())
   232  
   233      numbers = (
   234          main['numbers']
   235          | 'number windows' >> beam.WindowInto(
   236              FixedWindows(15),
   237              trigger=trigger.AfterWatermark(early=trigger.AfterCount(1)),
   238              accumulation_mode=trigger.AccumulationMode.DISCARDING)
   239          | 'number with key' >> beam.Map(lambda x: ('k', x))
   240          | 'number gbk' >> beam.GroupByKey())
   241  
   242      # The letters were emitted when the watermark was at 5, thus we expect to
   243      # see the elements in the [0, 15) window. We used an early trigger to make
   244      # sure that the ON_TIME empty pane was also emitted with a TestStream.
   245      # This pane has no data because of the early trigger causes the elements to
   246      # fire before the end of the window and because the accumulation mode
   247      # discards any data after the trigger fired.
   248      expected_letters = {
   249          window.IntervalWindow(0, 15): [
   250              ('k', ['a', 'b', 'c']),
   251              ('k', []),
   252          ],
   253      }
   254  
   255      # Same here, except the numbers were emitted at watermark = 20, thus they
   256      # are in the [15, 30) window.
   257      expected_numbers = {
   258          window.IntervalWindow(15, 30): [
   259              ('k', ['1', '2', '3']),
   260              ('k', []),
   261          ],
   262      }
   263      assert_that(
   264          letters,
   265          equal_to_per_window(expected_letters),
   266          label='letters assert per window')
   267      assert_that(
   268          numbers,
   269          equal_to_per_window(expected_numbers),
   270          label='numbers assert per window')
   271  
   272      p.run()
   273  
   274    def test_dicts_not_interpreted_as_windowed_values(self):
   275      test_stream = (TestStream()
   276                     .advance_processing_time(10)
   277                     .advance_watermark_to(10)
   278                     .add_elements([{'a': 0, 'b': 1, 'c': 2}])
   279                     .advance_watermark_to_infinity())  # yapf: disable
   280  
   281      class RecordFn(beam.DoFn):
   282        def process(
   283            self,
   284            element=beam.DoFn.ElementParam,
   285            timestamp=beam.DoFn.TimestampParam):
   286          yield (element, timestamp)
   287  
   288      options = PipelineOptions()
   289      options.view_as(StandardOptions).streaming = True
   290      with TestPipeline(options=options) as p:
   291        my_record_fn = RecordFn()
   292        records = p | test_stream | beam.ParDo(my_record_fn)
   293  
   294        assert_that(
   295            records,
   296            equal_to([
   297                ({
   298                    'a': 0, 'b': 1, 'c': 2
   299                }, timestamp.Timestamp(10)),
   300            ]))
   301  
   302    def test_windowed_values_interpreted_correctly(self):
   303      windowed_value = WindowedValueHolder(
   304          WindowedValue(
   305              'a',
   306              Timestamp(5), [beam.window.IntervalWindow(5, 10)],
   307              PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0)))
   308      test_stream = (TestStream()
   309                     .advance_processing_time(10)
   310                     .advance_watermark_to(10)
   311                     .add_elements([windowed_value])
   312                     .advance_watermark_to_infinity())  # yapf: disable
   313  
   314      class RecordFn(beam.DoFn):
   315        def process(
   316            self,
   317            element=beam.DoFn.ElementParam,
   318            timestamp=beam.DoFn.TimestampParam,
   319            window=beam.DoFn.WindowParam):
   320          yield (element, timestamp, window)
   321  
   322      options = PipelineOptions()
   323      options.view_as(StandardOptions).streaming = True
   324      with TestPipeline(options=options) as p:
   325        my_record_fn = RecordFn()
   326        records = p | test_stream | beam.ParDo(my_record_fn)
   327  
   328        assert_that(
   329            records,
   330            equal_to([
   331                ('a', timestamp.Timestamp(5), beam.window.IntervalWindow(5, 10)),
   332            ]))
   333  
   334    def test_instance_check_windowed_value_holder(self):
   335      windowed_value = WindowedValue(
   336          'a',
   337          Timestamp(5), [beam.window.IntervalWindow(5, 10)],
   338          PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0))
   339      self.assertTrue(
   340          isinstance(WindowedValueHolder(windowed_value), WindowedValueHolder))
   341      self.assertTrue(
   342          isinstance(
   343              beam.Row(
   344                  windowed_value=windowed_value, urn=common_urns.coders.ROW.urn),
   345              WindowedValueHolder))
   346      self.assertFalse(
   347          isinstance(
   348              beam.Row(windowed_value=windowed_value), WindowedValueHolder))
   349      self.assertFalse(isinstance(windowed_value, WindowedValueHolder))
   350      self.assertFalse(
   351          isinstance(beam.Row(x=windowed_value), WindowedValueHolder))
   352      self.assertFalse(
   353          isinstance(beam.Row(windowed_value=1), WindowedValueHolder))
   354  
   355    def test_gbk_execution_no_triggers(self):
   356      test_stream = (
   357          TestStream().advance_watermark_to(10).add_elements([
   358              'a', 'b', 'c'
   359          ]).advance_watermark_to(20).add_elements(['d']).add_elements([
   360              'e'
   361          ]).advance_processing_time(10).advance_watermark_to(300).add_elements([
   362              TimestampedValue('late', 12)
   363          ]).add_elements([TimestampedValue('last', 310)
   364                           ]).advance_watermark_to_infinity())
   365  
   366      options = PipelineOptions()
   367      options.view_as(StandardOptions).streaming = True
   368      p = TestPipeline(options=options)
   369      records = (
   370          p
   371          | test_stream
   372          | beam.WindowInto(FixedWindows(15), allowed_lateness=300)
   373          | beam.Map(lambda x: ('k', x))
   374          | beam.GroupByKey())
   375  
   376      # TODO(https://github.com/apache/beam/issues/18441): timestamp assignment
   377      # for elements from a GBK should respect the TimestampCombiner.  The test
   378      # below should also verify the timestamps of the outputted elements once
   379      # this is implemented.
   380  
   381      # assert per window
   382      expected_window_to_elements = {
   383          window.IntervalWindow(0, 15): [
   384              ('k', ['a', 'b', 'c']),
   385              ('k', ['late']),
   386          ],
   387          window.IntervalWindow(15, 30): [
   388              ('k', ['d', 'e']),
   389          ],
   390          window.IntervalWindow(300, 315): [
   391              ('k', ['last']),
   392          ],
   393      }
   394      assert_that(
   395          records,
   396          equal_to_per_window(expected_window_to_elements),
   397          label='assert per window')
   398  
   399      p.run()
   400  
   401    def test_gbk_execution_after_watermark_trigger(self):
   402      test_stream = (TestStream()
   403          .advance_watermark_to(10)
   404          .add_elements([TimestampedValue('a', 11)])
   405          .advance_watermark_to(20)
   406          .add_elements([TimestampedValue('b', 21)])
   407          .advance_watermark_to_infinity())  # yapf: disable
   408  
   409      options = PipelineOptions()
   410      options.view_as(StandardOptions).streaming = True
   411      p = TestPipeline(options=options)
   412      records = (
   413          p  # pylint: disable=unused-variable
   414          | test_stream
   415          | beam.WindowInto(
   416              FixedWindows(15),
   417              trigger=trigger.AfterWatermark(early=trigger.AfterCount(1)),
   418              accumulation_mode=trigger.AccumulationMode.DISCARDING)
   419          | beam.Map(lambda x: ('k', x))
   420          | beam.GroupByKey())
   421  
   422      # TODO(https://github.com/apache/beam/issues/18441): timestamp assignment
   423      # for elements from a GBK should respect the TimestampCombiner.  The test
   424      # below should also verify the timestamps of the outputted elements once
   425      # this is implemented.
   426  
   427      # assert per window
   428      expected_window_to_elements = {
   429          window.IntervalWindow(0, 15): [('k', ['a']), ('k', [])],
   430          window.IntervalWindow(15, 30): [('k', ['b']), ('k', [])],
   431      }
   432      assert_that(
   433          records,
   434          equal_to_per_window(expected_window_to_elements),
   435          label='assert per window')
   436  
   437      p.run()
   438  
   439    def test_gbk_execution_after_processing_trigger_fired(self):
   440      """Advance TestClock to (X + delta) and see the pipeline does finish."""
   441      # TODO(mariagh): Add test_gbk_execution_after_processing_trigger_unfired
   442      # Advance TestClock to (X + delta) and see the pipeline does finish
   443      # Possibly to the framework trigger_transcripts.yaml
   444  
   445      test_stream = (TestStream()
   446          .advance_watermark_to(10)
   447          .add_elements(['a'])
   448          .advance_processing_time(5.1)
   449          .advance_watermark_to_infinity())  # yapf: disable
   450  
   451      options = PipelineOptions()
   452      options.view_as(StandardOptions).streaming = True
   453      options.view_as(TypeOptions).allow_unsafe_triggers = True
   454      p = TestPipeline(options=options)
   455      records = (
   456          p
   457          | test_stream
   458          | beam.WindowInto(
   459              beam.window.FixedWindows(15),
   460              trigger=trigger.AfterProcessingTime(5),
   461              accumulation_mode=trigger.AccumulationMode.DISCARDING)
   462          | beam.Map(lambda x: ('k', x))
   463          | beam.GroupByKey())
   464  
   465      # TODO(https://github.com/apache/beam/issues/18441): timestamp assignment
   466      # for elements from a GBK should respect the TimestampCombiner.  The test
   467      # below should also verify the timestamps of the outputted elements once
   468      # this is implemented.
   469  
   470      expected_window_to_elements = {
   471          window.IntervalWindow(0, 15): [('k', ['a'])],
   472      }
   473      assert_that(
   474          records,
   475          equal_to_per_window(expected_window_to_elements),
   476          label='assert per window')
   477  
   478      p.run()
   479  
   480    def test_basic_execution_batch_sideinputs(self):
   481      options = PipelineOptions()
   482      options.view_as(StandardOptions).streaming = True
   483      p = TestPipeline(options=options)
   484  
   485      main_stream = (p
   486                     | 'main TestStream' >> TestStream()
   487                     .advance_watermark_to(10)
   488                     .add_elements(['e'])
   489                     .advance_watermark_to_infinity())  # yapf: disable
   490      side = (
   491          p
   492          | beam.Create([2, 1, 4])
   493          | beam.Map(lambda t: window.TimestampedValue(t, t)))
   494  
   495      class RecordFn(beam.DoFn):
   496        def process(
   497            self,
   498            elm=beam.DoFn.ElementParam,
   499            ts=beam.DoFn.TimestampParam,
   500            side=beam.DoFn.SideInputParam):
   501          yield (elm, ts, sorted(side))
   502  
   503      records = (
   504          main_stream  # pylint: disable=unused-variable
   505          | beam.ParDo(RecordFn(), beam.pvalue.AsList(side)))
   506  
   507      assert_that(records, equal_to([('e', Timestamp(10), [1, 2, 4])]))
   508  
   509      p.run()
   510  
   511    def test_basic_execution_sideinputs(self):
   512      options = PipelineOptions()
   513      options.view_as(StandardOptions).streaming = True
   514      with TestPipeline(options=options) as p:
   515  
   516        test_stream = (p | TestStream()
   517            .advance_watermark_to(0, tag='side')
   518            .advance_watermark_to(10, tag='main')
   519            .add_elements(['e'], tag='main')
   520            .add_elements([window.TimestampedValue(2, 2)], tag='side')
   521            .add_elements([window.TimestampedValue(1, 1)], tag='side')
   522            .add_elements([window.TimestampedValue(7, 7)], tag='side')
   523            .add_elements([window.TimestampedValue(4, 4)], tag='side')
   524            ) # yapf: disable
   525  
   526        main_stream = test_stream['main']
   527        side_stream = test_stream['side']
   528  
   529        class RecordFn(beam.DoFn):
   530          def process(
   531              self,
   532              elm=beam.DoFn.ElementParam,
   533              ts=beam.DoFn.TimestampParam,
   534              side=beam.DoFn.SideInputParam):
   535            yield (elm, ts, side)
   536  
   537        records = (
   538            main_stream  # pylint: disable=unused-variable
   539            | beam.ParDo(RecordFn(), beam.pvalue.AsList(side_stream)))
   540  
   541        assert_that(records, equal_to([('e', Timestamp(10), [2, 1, 7, 4])]))
   542  
   543    def test_basic_execution_batch_sideinputs_fixed_windows(self):
   544      options = PipelineOptions()
   545      options.view_as(StandardOptions).streaming = True
   546      p = TestPipeline(options=options)
   547  
   548      main_stream = (
   549          p
   550          |
   551          'main TestStream' >> TestStream().advance_watermark_to(2).add_elements(
   552              ['a']).advance_watermark_to(4).add_elements(
   553                  ['b']).advance_watermark_to_infinity()
   554          | 'main window' >> beam.WindowInto(window.FixedWindows(1)))
   555      side = (
   556          p
   557          | beam.Create([2, 1, 4])
   558          | beam.Map(lambda t: window.TimestampedValue(t, t))
   559          | beam.WindowInto(window.FixedWindows(2)))
   560  
   561      class RecordFn(beam.DoFn):
   562        def process(
   563            self,
   564            elm=beam.DoFn.ElementParam,
   565            ts=beam.DoFn.TimestampParam,
   566            side=beam.DoFn.SideInputParam):
   567          yield (elm, ts, side)
   568  
   569      records = (
   570          main_stream  # pylint: disable=unused-variable
   571          | beam.ParDo(RecordFn(), beam.pvalue.AsList(side)))
   572  
   573      # assert per window
   574      expected_window_to_elements = {
   575          window.IntervalWindow(2, 3): [('a', Timestamp(2), [2])],
   576          window.IntervalWindow(4, 5): [('b', Timestamp(4), [4])]
   577      }
   578      assert_that(
   579          records,
   580          equal_to_per_window(expected_window_to_elements),
   581          label='assert per window')
   582  
   583      p.run()
   584  
   585    def test_basic_execution_sideinputs_fixed_windows(self):
   586      options = PipelineOptions()
   587      options.view_as(StandardOptions).streaming = True
   588      p = TestPipeline(options=options)
   589  
   590      test_stream = (p | TestStream()
   591          .advance_watermark_to(12, tag='side')
   592          .add_elements([window.TimestampedValue('s1', 10)], tag='side')
   593          .advance_watermark_to(20, tag='side')
   594          .add_elements([window.TimestampedValue('s2', 20)], tag='side')
   595  
   596          .advance_watermark_to(9, tag='main')
   597          .add_elements(['a1', 'a2', 'a3', 'a4'], tag='main')
   598          .add_elements(['b'], tag='main')
   599          .advance_watermark_to(18, tag='main')
   600          .add_elements('c', tag='main')
   601          ) # yapf: disable
   602  
   603      main_stream = (
   604          test_stream['main']
   605          | 'main windowInto' >> beam.WindowInto(window.FixedWindows(1)))
   606  
   607      side_stream = (
   608          test_stream['side']
   609          | 'side windowInto' >> beam.WindowInto(window.FixedWindows(3)))
   610  
   611      class RecordFn(beam.DoFn):
   612        def process(
   613            self,
   614            elm=beam.DoFn.ElementParam,
   615            ts=beam.DoFn.TimestampParam,
   616            side=beam.DoFn.SideInputParam):
   617          yield (elm, ts, side)
   618  
   619      records = (
   620          main_stream  # pylint: disable=unused-variable
   621          | beam.ParDo(RecordFn(), beam.pvalue.AsList(side_stream)))
   622  
   623      # assert per window
   624      expected_window_to_elements = {
   625          window.IntervalWindow(9, 10): [
   626              ('a1', Timestamp(9), ['s1']), ('a2', Timestamp(9), ['s1']),
   627              ('a3', Timestamp(9), ['s1']), ('a4', Timestamp(9), ['s1']),
   628              ('b', Timestamp(9), ['s1'])
   629          ],
   630          window.IntervalWindow(18, 19): [('c', Timestamp(18), ['s2'])],
   631      }
   632      assert_that(
   633          records,
   634          equal_to_per_window(expected_window_to_elements),
   635          label='assert per window')
   636  
   637      p.run()
   638  
   639    def test_roundtrip_proto(self):
   640      test_stream = (TestStream()
   641                     .advance_processing_time(1)
   642                     .advance_watermark_to(2)
   643                     .add_elements([1, 2, 3])) # yapf: disable
   644  
   645      p = TestPipeline(options=StandardOptions(streaming=True))
   646      p | test_stream
   647  
   648      pipeline_proto, context = p.to_runner_api(return_context=True)
   649  
   650      for t in pipeline_proto.components.transforms.values():
   651        if t.spec.urn == common_urns.primitives.TEST_STREAM.urn:
   652          test_stream_proto = t
   653  
   654      self.assertTrue(test_stream_proto)
   655      roundtrip_test_stream = TestStream().from_runner_api(
   656          test_stream_proto, context)
   657  
   658      self.assertListEqual(test_stream._events, roundtrip_test_stream._events)
   659      self.assertSetEqual(
   660          test_stream.output_tags, roundtrip_test_stream.output_tags)
   661      self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
   662  
   663    def test_roundtrip_proto_multi(self):
   664      test_stream = (TestStream()
   665                     .advance_processing_time(1)
   666                     .advance_watermark_to(2, tag='a')
   667                     .advance_watermark_to(3, tag='b')
   668                     .add_elements([1, 2, 3], tag='a')
   669                     .add_elements([4, 5, 6], tag='b')) # yapf: disable
   670  
   671      options = StandardOptions(streaming=True)
   672  
   673      p = TestPipeline(options=options)
   674      p | test_stream
   675  
   676      pipeline_proto, context = p.to_runner_api(return_context=True)
   677  
   678      for t in pipeline_proto.components.transforms.values():
   679        if t.spec.urn == common_urns.primitives.TEST_STREAM.urn:
   680          test_stream_proto = t
   681  
   682      self.assertTrue(test_stream_proto)
   683      roundtrip_test_stream = TestStream().from_runner_api(
   684          test_stream_proto, context)
   685  
   686      self.assertListEqual(test_stream._events, roundtrip_test_stream._events)
   687      self.assertSetEqual(
   688          test_stream.output_tags, roundtrip_test_stream.output_tags)
   689      self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
   690  
   691    def test_basic_execution_with_service(self):
   692      """Tests that the TestStream can correctly read from an RPC service.
   693      """
   694      coder = beam.coders.FastPrimitivesCoder()
   695  
   696      test_stream_events = (TestStream(coder=coder)
   697          .advance_watermark_to(10000)
   698          .add_elements(['a', 'b', 'c'])
   699          .advance_watermark_to(20000)
   700          .add_elements(['d'])
   701          .add_elements(['e'])
   702          .advance_processing_time(10)
   703          .advance_watermark_to(300000)
   704          .add_elements([TimestampedValue('late', 12000)])
   705          .add_elements([TimestampedValue('last', 310000)])
   706          .advance_watermark_to_infinity())._events  # yapf: disable
   707  
   708      test_stream_proto_events = [
   709          e.to_runner_api(coder) for e in test_stream_events
   710      ]
   711  
   712      class InMemoryEventReader:
   713        def read_multiple(self, unused_keys):
   714          for e in test_stream_proto_events:
   715            yield e
   716  
   717      service = TestStreamServiceController(reader=InMemoryEventReader())
   718      service.start()
   719  
   720      test_stream = TestStream(coder=coder, endpoint=service.endpoint)
   721  
   722      class RecordFn(beam.DoFn):
   723        def process(
   724            self,
   725            element=beam.DoFn.ElementParam,
   726            timestamp=beam.DoFn.TimestampParam):
   727          yield (element, timestamp)
   728  
   729      options = StandardOptions(streaming=True)
   730  
   731      p = TestPipeline(options=options)
   732      my_record_fn = RecordFn()
   733      records = p | test_stream | beam.ParDo(my_record_fn)
   734  
   735      assert_that(
   736          records,
   737          equal_to([
   738              ('a', timestamp.Timestamp(10)),
   739              ('b', timestamp.Timestamp(10)),
   740              ('c', timestamp.Timestamp(10)),
   741              ('d', timestamp.Timestamp(20)),
   742              ('e', timestamp.Timestamp(20)),
   743              ('late', timestamp.Timestamp(12)),
   744              ('last', timestamp.Timestamp(310)),
   745          ]))
   746  
   747      p.run()
   748  
   749  
   750  class ReverseTestStreamTest(unittest.TestCase):
   751    def test_basic_execution(self):
   752      test_stream = (TestStream()
   753                     .advance_watermark_to(0)
   754                     .advance_processing_time(5)
   755                     .add_elements(['a', 'b', 'c'])
   756                     .advance_watermark_to(2)
   757                     .advance_processing_time(1)
   758                     .advance_watermark_to(4)
   759                     .advance_processing_time(1)
   760                     .advance_watermark_to(6)
   761                     .advance_processing_time(1)
   762                     .advance_watermark_to(8)
   763                     .advance_processing_time(1)
   764                     .advance_watermark_to(10)
   765                     .advance_processing_time(1)
   766                     .add_elements([TimestampedValue('1', 15),
   767                                    TimestampedValue('2', 15),
   768                                    TimestampedValue('3', 15)]))  # yapf: disable
   769  
   770      options = StandardOptions(streaming=True)
   771      p = TestPipeline(options=options)
   772  
   773      records = (
   774          p
   775          | test_stream
   776          | ReverseTestStream(sample_resolution_sec=1, output_tag=None))
   777  
   778      assert_that(
   779          records,
   780          equal_to_per_window({
   781              beam.window.GlobalWindow(): [
   782                  [ProcessingTimeEvent(5), WatermarkEvent(0)],
   783                  [
   784                      ElementEvent([
   785                          TimestampedValue('a', 0),
   786                          TimestampedValue('b', 0),
   787                          TimestampedValue('c', 0)
   788                      ])
   789                  ],
   790                  [ProcessingTimeEvent(1), WatermarkEvent(2000000)],
   791                  [ProcessingTimeEvent(1), WatermarkEvent(4000000)],
   792                  [ProcessingTimeEvent(1), WatermarkEvent(6000000)],
   793                  [ProcessingTimeEvent(1), WatermarkEvent(8000000)],
   794                  [ProcessingTimeEvent(1), WatermarkEvent(10000000)],
   795                  [
   796                      ElementEvent([
   797                          TimestampedValue('1', 15),
   798                          TimestampedValue('2', 15),
   799                          TimestampedValue('3', 15)
   800                      ])
   801                  ],
   802              ],
   803          }))
   804  
   805      p.run()
   806  
   807    def test_windowing(self):
   808      test_stream = (TestStream()
   809                     .advance_watermark_to(0)
   810                     .add_elements(['a', 'b', 'c'])
   811                     .advance_processing_time(1)
   812                     .advance_processing_time(1)
   813                     .advance_processing_time(1)
   814                     .advance_processing_time(1)
   815                     .advance_processing_time(1)
   816                     .advance_watermark_to(5)
   817                     .add_elements(['1', '2', '3'])
   818                     .advance_processing_time(1)
   819                     .advance_watermark_to(6)
   820                     .advance_processing_time(1)
   821                     .advance_watermark_to(7)
   822                     .advance_processing_time(1)
   823                     .advance_watermark_to(8)
   824                     .advance_processing_time(1)
   825                     .advance_watermark_to(9)
   826                     .advance_processing_time(1)
   827                     .advance_watermark_to(10)
   828                     .advance_processing_time(1)
   829                     .advance_watermark_to(11)
   830                     .advance_processing_time(1)
   831                     .advance_watermark_to(12)
   832                     .advance_processing_time(1)
   833                     .advance_watermark_to(13)
   834                     .advance_processing_time(1)
   835                     .advance_watermark_to(14)
   836                     .advance_processing_time(1)
   837                     .advance_watermark_to(15)
   838                     .advance_processing_time(1)
   839                     )  # yapf: disable
   840  
   841      options = StandardOptions(streaming=True)
   842      p = TestPipeline(options=options)
   843  
   844      records = (
   845          p
   846          | test_stream
   847          | 'letter windows' >> beam.WindowInto(
   848              FixedWindows(5),
   849              accumulation_mode=trigger.AccumulationMode.DISCARDING)
   850          | 'letter with key' >> beam.Map(lambda x: ('k', x))
   851          | 'letter gbk' >> beam.GroupByKey()
   852          | ReverseTestStream(sample_resolution_sec=1, output_tag=None))
   853  
   854      assert_that(
   855          records,
   856          equal_to_per_window({
   857              beam.window.GlobalWindow(): [
   858                  [ProcessingTimeEvent(5), WatermarkEvent(4999998)],
   859                  [
   860                      ElementEvent(
   861                          [TimestampedValue(('k', ['a', 'b', 'c']), 4.999999)])
   862                  ],
   863                  [ProcessingTimeEvent(1), WatermarkEvent(5000000)],
   864                  [ProcessingTimeEvent(1), WatermarkEvent(6000000)],
   865                  [ProcessingTimeEvent(1), WatermarkEvent(7000000)],
   866                  [ProcessingTimeEvent(1), WatermarkEvent(8000000)],
   867                  [ProcessingTimeEvent(1), WatermarkEvent(9000000)],
   868                  [
   869                      ElementEvent(
   870                          [TimestampedValue(('k', ['1', '2', '3']), 9.999999)])
   871                  ],
   872                  [ProcessingTimeEvent(1), WatermarkEvent(10000000)],
   873                  [ProcessingTimeEvent(1), WatermarkEvent(11000000)],
   874                  [ProcessingTimeEvent(1), WatermarkEvent(12000000)],
   875                  [ProcessingTimeEvent(1), WatermarkEvent(13000000)],
   876                  [ProcessingTimeEvent(1), WatermarkEvent(14000000)],
   877                  [ProcessingTimeEvent(1), WatermarkEvent(15000000)],
   878              ],
   879          }))
   880  
   881      p.run()
   882  
   883    def test_basic_execution_in_records_format(self):
   884      test_stream = (TestStream()
   885                     .advance_watermark_to(0)
   886                     .advance_processing_time(5)
   887                     .add_elements(['a', 'b', 'c'])
   888                     .advance_watermark_to(2)
   889                     .advance_processing_time(1)
   890                     .advance_watermark_to(4)
   891                     .advance_processing_time(1)
   892                     .advance_watermark_to(6)
   893                     .advance_processing_time(1)
   894                     .advance_watermark_to(8)
   895                     .advance_processing_time(1)
   896                     .advance_watermark_to(10)
   897                     .advance_processing_time(1)
   898                     .add_elements([TimestampedValue('1', 15),
   899                                    TimestampedValue('2', 15),
   900                                    TimestampedValue('3', 15)]))  # yapf: disable
   901  
   902      options = StandardOptions(streaming=True)
   903      p = TestPipeline(options=options)
   904  
   905      coder = beam.coders.FastPrimitivesCoder()
   906      records = (
   907          p
   908          | test_stream
   909          | ReverseTestStream(
   910              sample_resolution_sec=1,
   911              coder=coder,
   912              output_format=OutputFormat.TEST_STREAM_FILE_RECORDS,
   913              output_tag=None)
   914          | 'stringify' >> beam.Map(str))
   915  
   916      assert_that(
   917          records,
   918          equal_to_per_window({
   919              beam.window.GlobalWindow(): [
   920                  str(beam_interactive_api_pb2.TestStreamFileHeader()),
   921                  str(
   922                      beam_interactive_api_pb2.TestStreamFileRecord(
   923                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
   924                          Event(
   925                              processing_time_event=beam_runner_api_pb2.
   926                              TestStreamPayload.Event.AdvanceProcessingTime(
   927                                  advance_duration=5000000)))),
   928                  str(
   929                      beam_interactive_api_pb2.TestStreamFileRecord(
   930                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
   931                          Event(
   932                              watermark_event=beam_runner_api_pb2.
   933                              TestStreamPayload.Event.AdvanceWatermark(
   934                                  new_watermark=0)))),
   935                  str(
   936                      beam_interactive_api_pb2.TestStreamFileRecord(
   937                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
   938                          Event(
   939                              element_event=beam_runner_api_pb2.TestStreamPayload.
   940                              Event.AddElements(
   941                                  elements=[
   942                                      beam_runner_api_pb2.TestStreamPayload.
   943                                      TimestampedElement(
   944                                          encoded_element=coder.encode('a'),
   945                                          timestamp=0),
   946                                      beam_runner_api_pb2.TestStreamPayload.
   947                                      TimestampedElement(
   948                                          encoded_element=coder.encode('b'),
   949                                          timestamp=0),
   950                                      beam_runner_api_pb2.TestStreamPayload.
   951                                      TimestampedElement(
   952                                          encoded_element=coder.encode('c'),
   953                                          timestamp=0),
   954                                  ])))),
   955                  str(
   956                      beam_interactive_api_pb2.TestStreamFileRecord(
   957                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
   958                          Event(
   959                              watermark_event=beam_runner_api_pb2.
   960                              TestStreamPayload.Event.AdvanceWatermark(
   961                                  new_watermark=2000000)))),
   962                  str(
   963                      beam_interactive_api_pb2.TestStreamFileRecord(
   964                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
   965                          Event(
   966                              processing_time_event=beam_runner_api_pb2.
   967                              TestStreamPayload.Event.AdvanceProcessingTime(
   968                                  advance_duration=1000000)))),
   969                  str(
   970                      beam_interactive_api_pb2.TestStreamFileRecord(
   971                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
   972                          Event(
   973                              watermark_event=beam_runner_api_pb2.
   974                              TestStreamPayload.Event.AdvanceWatermark(
   975                                  new_watermark=4000000)))),
   976                  str(
   977                      beam_interactive_api_pb2.TestStreamFileRecord(
   978                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
   979                          Event(
   980                              processing_time_event=beam_runner_api_pb2.
   981                              TestStreamPayload.Event.AdvanceProcessingTime(
   982                                  advance_duration=1000000)))),
   983                  str(
   984                      beam_interactive_api_pb2.TestStreamFileRecord(
   985                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
   986                          Event(
   987                              watermark_event=beam_runner_api_pb2.
   988                              TestStreamPayload.Event.AdvanceWatermark(
   989                                  new_watermark=6000000)))),
   990                  str(
   991                      beam_interactive_api_pb2.TestStreamFileRecord(
   992                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
   993                          Event(
   994                              processing_time_event=beam_runner_api_pb2.
   995                              TestStreamPayload.Event.AdvanceProcessingTime(
   996                                  advance_duration=1000000)))),
   997                  str(
   998                      beam_interactive_api_pb2.TestStreamFileRecord(
   999                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
  1000                          Event(
  1001                              watermark_event=beam_runner_api_pb2.
  1002                              TestStreamPayload.Event.AdvanceWatermark(
  1003                                  new_watermark=8000000)))),
  1004                  str(
  1005                      beam_interactive_api_pb2.TestStreamFileRecord(
  1006                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
  1007                          Event(
  1008                              processing_time_event=beam_runner_api_pb2.
  1009                              TestStreamPayload.Event.AdvanceProcessingTime(
  1010                                  advance_duration=1000000)))),
  1011                  str(
  1012                      beam_interactive_api_pb2.TestStreamFileRecord(
  1013                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
  1014                          Event(
  1015                              watermark_event=beam_runner_api_pb2.
  1016                              TestStreamPayload.Event.AdvanceWatermark(
  1017                                  new_watermark=10000000)))),
  1018                  str(
  1019                      beam_interactive_api_pb2.TestStreamFileRecord(
  1020                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
  1021                          Event(
  1022                              processing_time_event=beam_runner_api_pb2.
  1023                              TestStreamPayload.Event.AdvanceProcessingTime(
  1024                                  advance_duration=1000000)))),
  1025                  str(
  1026                      beam_interactive_api_pb2.TestStreamFileRecord(
  1027                          recorded_event=beam_runner_api_pb2.TestStreamPayload.
  1028                          Event(
  1029                              element_event=beam_runner_api_pb2.TestStreamPayload.
  1030                              Event.AddElements(
  1031                                  elements=[
  1032                                      beam_runner_api_pb2.TestStreamPayload.
  1033                                      TimestampedElement(
  1034                                          encoded_element=coder.encode('1'),
  1035                                          timestamp=15000000),
  1036                                      beam_runner_api_pb2.TestStreamPayload.
  1037                                      TimestampedElement(
  1038                                          encoded_element=coder.encode('2'),
  1039                                          timestamp=15000000),
  1040                                      beam_runner_api_pb2.TestStreamPayload.
  1041                                      TimestampedElement(
  1042                                          encoded_element=coder.encode('3'),
  1043                                          timestamp=15000000),
  1044                                  ])))),
  1045              ],
  1046          }))
  1047  
  1048      p.run()
  1049  
  1050  
  1051  if __name__ == '__main__':
  1052    unittest.main()