github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/trigger_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the triggering classes."""

# pytype: skip-file

import collections
import json
import os.path
import pickle
import random
import unittest

import yaml

import apache_beam as beam
from apache_beam import coders
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
from apache_beam.runners import pipeline_context
from apache_beam.runners.direct.clock import TestClock
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import WindowInto
from apache_beam.transforms import ptransform
from apache_beam.transforms import trigger
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.trigger import AccumulationMode
from apache_beam.transforms.trigger import AfterAll
from apache_beam.transforms.trigger import AfterAny
from apache_beam.transforms.trigger import AfterCount
from apache_beam.transforms.trigger import AfterEach
from apache_beam.transforms.trigger import AfterProcessingTime
from apache_beam.transforms.trigger import AfterWatermark
from apache_beam.transforms.trigger import Always
from apache_beam.transforms.trigger import DataLossReason
from apache_beam.transforms.trigger import DefaultTrigger
from apache_beam.transforms.trigger import GeneralTriggerDriver
from apache_beam.transforms.trigger import InMemoryUnmergedState
from apache_beam.transforms.trigger import Repeatedly
from apache_beam.transforms.trigger import TriggerFn
from apache_beam.transforms.trigger import _Never
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import IntervalWindow
from apache_beam.transforms.window import Sessions
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
from apache_beam.transforms.window import WindowedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.timestamp import Duration
from apache_beam.utils.windowed_value import PaneInfoTiming


class CustomTimestampingFixedWindowsWindowFn(FixedWindows):
  """WindowFn for testing custom timestamping."""
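  # Shifts each input's output timestamp forward by 100. Transcript specs can
  # refer to this class by name (it is registered in TranscriptTest._run_log
  # below), presumably to exercise the OUTPUT_AT_EARLIEST_TRANSFORMED
  # timestamp combiner, which is assumed here to consult
  # get_transformed_output_time.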
  def get_transformed_output_time(self, unused_window, input_timestamp):
    return input_timestamp + 100


class TriggerTest(unittest.TestCase):
  def run_trigger_simple(
      self,
      window_fn,
      trigger_fn,
      accumulation_mode,
      timestamped_data,
      expected_panes,
      *groupings,
      **kwargs):
    # Groupings is a list of integers indicating the (uniform) size of bundles
    # to try. For example, if timestamped_data has elements [a, b, c, d, e],
    # then groupings=(5, 2) would first run the test with everything in the
    # same bundle, and then re-run it with the bundling [a, b], [c, d], [e].
    # A negative value reverses the order, e.g. -2 results in the bundles
    # [e, d], [c, b], [a].  For deterministic triggers, this is useful for
    # testing that the output is not a function of ordering or bundling.
    # If empty, groupings defaults to bundles of size 1 in the given order.
    late_data = kwargs.pop('late_data', [])
    assert not kwargs

    def bundle_data(data, size):
      if size < 0:
        data = list(data)[::-1]
        size = -size
      bundle = []
      for timestamp, elem in data:
        windows = window_fn.assign(WindowFn.AssignContext(timestamp, elem))
        bundle.append(WindowedValue(elem, timestamp, windows))
        if len(bundle) == size:
          yield bundle
          bundle = []
      if bundle:
        yield bundle

    if not groupings:
      groupings = [1]
    for group_by in groupings:
      self.run_trigger(
          window_fn,
          trigger_fn,
          accumulation_mode,
          bundle_data(timestamped_data, group_by),
          bundle_data(late_data, group_by),
          expected_panes)

  def run_trigger(
      self,
      window_fn,
      trigger_fn,
      accumulation_mode,
      bundles,
      late_bundles,
      expected_panes):
    actual_panes = collections.defaultdict(list)
    allowed_lateness = Duration(
        micros=int(common_urns.constants.MAX_TIMESTAMP_MILLIS.constant) * 1000)
    driver = GeneralTriggerDriver(
        Windowing(
            window_fn,
            trigger_fn,
            accumulation_mode,
            allowed_lateness=allowed_lateness),
        TestClock())
    state = InMemoryUnmergedState()

    for bundle in bundles:
      for wvalue in driver.process_elements(state,
                                            bundle,
                                            MIN_TIMESTAMP,
                                            MIN_TIMESTAMP):
        window, = wvalue.windows
        self.assertEqual(window.max_timestamp(), wvalue.timestamp)
        actual_panes[window].append(set(wvalue.value))

    while state.timers:
      for timer_window, (name, time_domain, timestamp,
                         _) in state.get_and_clear_timers():
        for wvalue in driver.process_timer(timer_window,
                                           name,
                                           time_domain,
                                           timestamp,
                                           state,
                                           MIN_TIMESTAMP):
          window, = wvalue.windows
          self.assertEqual(window.max_timestamp(), wvalue.timestamp)
          actual_panes[window].append(set(wvalue.value))

    for bundle in late_bundles:
      for wvalue in driver.process_elements(state,
                                            bundle,
                                            MAX_TIMESTAMP,
                                            MAX_TIMESTAMP):
        window, = wvalue.windows
        self.assertEqual(window.max_timestamp(), wvalue.timestamp)
        actual_panes[window].append(set(wvalue.value))

      while state.timers:
        for timer_window, (name, time_domain, timestamp,
                           _) in state.get_and_clear_timers():
          for wvalue in driver.process_timer(timer_window,
                                             name,
                                             time_domain,
                                             timestamp,
                                             state,
                                             MAX_TIMESTAMP):
            window, = wvalue.windows
            self.assertEqual(window.max_timestamp(), wvalue.timestamp)
            actual_panes[window].append(set(wvalue.value))

    self.assertEqual(expected_panes, actual_panes)

  def test_fixed_watermark(self):
    self.run_trigger_simple(
        FixedWindows(10),  # pyformat break
        AfterWatermark(),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (13, 'c')],
        {IntervalWindow(0, 10): [set('ab')],
         IntervalWindow(10, 20): [set('c')]},
        1,
        2,
        3,
        -3,
        -2,
        -1)

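  def test_fixed_default(self):
    # A minimal illustrative sketch (not part of the original suite): with no
    # late data, DefaultTrigger() is expected to behave like AfterWatermark()
    # above, firing each fixed window exactly once when the watermark passes
    # the end of the window.
    self.run_trigger_simple(
        FixedWindows(10),  # pyformat break
        DefaultTrigger(),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (13, 'c')],
        {IntervalWindow(0, 10): [set('ab')],
         IntervalWindow(10, 20): [set('c')]},
        1,
        2)
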
  def test_fixed_watermark_with_early(self):
    self.run_trigger_simple(
        FixedWindows(10),  # pyformat break
        AfterWatermark(early=AfterCount(2)),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (3, 'c')],
        {IntervalWindow(0, 10): [set('ab'), set('abc')]},
        2)
    self.run_trigger_simple(
        FixedWindows(10),  # pyformat break
        AfterWatermark(early=AfterCount(2)),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (3, 'c')],
        {IntervalWindow(0, 10): [set('abc'), set('abc')]},
        3)

  def test_fixed_watermark_with_early_late(self):
    self.run_trigger_simple(
        FixedWindows(100),  # pyformat break
        AfterWatermark(early=AfterCount(3),
                       late=AfterCount(2)),
        AccumulationMode.DISCARDING,
        zip(range(9), 'abcdefghi'),
        {IntervalWindow(0, 100): [
            set('abcd'), set('efgh'),  # early
            set('i'),                  # on time
            set('vw'), set('xy')       # late
            ]},
        2,
        late_data=zip(range(5), 'vwxyz'))

  def test_sessions_watermark_with_early_late(self):
    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        AfterWatermark(early=AfterCount(2),
                       late=AfterCount(1)),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (15, 'b'), (7, 'c'), (30, 'd')],
        {
            IntervalWindow(1, 25): [
                set('abc'),                # early
                set('abc'),                # on time
                set('abcxy')               # late
            ],
            IntervalWindow(30, 40): [
                set('d'),                  # on time
            ],
            IntervalWindow(1, 40): [
                set('abcdxyz')             # late
            ],
        },
        2,
        late_data=[(1, 'x'), (2, 'y'), (21, 'z')])

  def test_fixed_after_count(self):
    self.run_trigger_simple(
        FixedWindows(10),  # pyformat break
        AfterCount(2),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (3, 'c'), (11, 'z')],
        {IntervalWindow(0, 10): [set('ab')]},
        1,
        2)
    self.run_trigger_simple(
        FixedWindows(10),  # pyformat break
        AfterCount(2),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (3, 'c'), (11, 'z')],
        {IntervalWindow(0, 10): [set('abc')]},
        3,
        4)

  def test_fixed_after_first(self):
    self.run_trigger_simple(
        FixedWindows(10),  # pyformat break
        AfterAny(AfterCount(2), AfterWatermark()),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (3, 'c')],
        {IntervalWindow(0, 10): [set('ab')]},
        1,
        2)
    self.run_trigger_simple(
        FixedWindows(10),  # pyformat break
        AfterAny(AfterCount(5), AfterWatermark()),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (3, 'c')],
        {IntervalWindow(0, 10): [set('abc')]},
        1,
        2,
        late_data=[(1, 'x'), (2, 'y'), (3, 'z')])

  def test_repeatedly_after_first(self):
    self.run_trigger_simple(
        FixedWindows(100),  # pyformat break
        Repeatedly(AfterAny(AfterCount(3), AfterWatermark())),
        AccumulationMode.ACCUMULATING,
        zip(range(7), 'abcdefg'),
        {IntervalWindow(0, 100): [
            set('abc'),
            set('abcdef'),
            set('abcdefg'),
            set('abcdefgx'),
            set('abcdefgxy'),
            set('abcdefgxyz')]},
        1,
        late_data=zip(range(3), 'xyz'))

  def test_sessions_after_all(self):
    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        AfterAll(AfterCount(2), AfterWatermark()),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (3, 'c')],
        {IntervalWindow(1, 13): [set('abc')]},
        1,
        2)
    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        AfterAll(AfterCount(5), AfterWatermark()),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (3, 'c')],
        {IntervalWindow(1, 13): [set('abcxy')]},
        1,
        2,
        late_data=[(1, 'x'), (2, 'y'), (3, 'z')])

  def test_sessions_default(self):
    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        DefaultTrigger(),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b')],
        {IntervalWindow(1, 12): [set('ab')]},
        1,
        2,
        -2,
        -1)

    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        AfterWatermark(),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b'), (15, 'c'), (16, 'd'), (30, 'z'), (9, 'e'),
         (10, 'f'), (30, 'y')],
        {IntervalWindow(1, 26): [set('abcdef')],
         IntervalWindow(30, 40): [set('yz')]},
        1,
        2,
        3,
        4,
        5,
        6,
        -4,
        -2,
        -1)

  def test_sessions_watermark(self):
    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        AfterWatermark(),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (2, 'b')],
        {IntervalWindow(1, 12): [set('ab')]},
        1,
        2,
        -2,
        -1)

  def test_sessions_after_count(self):
    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        AfterCount(2),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (15, 'b'), (6, 'c'), (30, 's'), (31, 't'), (50, 'z'),
         (50, 'y')],
        {IntervalWindow(1, 25): [set('abc')],
         IntervalWindow(30, 41): [set('st')],
         IntervalWindow(50, 60): [set('yz')]},
        1,
        2,
        3)

  def test_sessions_repeatedly_after_count(self):
    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        Repeatedly(AfterCount(2)),
        AccumulationMode.ACCUMULATING,
        [(1, 'a'), (15, 'b'), (6, 'c'), (2, 'd'), (7, 'e')],
        {IntervalWindow(1, 25): [set('abc'), set('abcde')]},
        1,
        3)
    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        Repeatedly(AfterCount(2)),
        AccumulationMode.DISCARDING,
        [(1, 'a'), (15, 'b'), (6, 'c'), (2, 'd'), (7, 'e')],
        {IntervalWindow(1, 25): [set('abc'), set('de')]},
        1,
        3)

  def test_sessions_after_each(self):
    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        AfterEach(AfterCount(2), AfterCount(3)),
        AccumulationMode.ACCUMULATING,
        zip(range(10), 'abcdefghij'),
        {IntervalWindow(0, 11): [set('ab')],
         IntervalWindow(0, 15): [set('abcdef')]},
        2)

    self.run_trigger_simple(
        Sessions(10),  # pyformat break
        Repeatedly(AfterEach(AfterCount(2), AfterCount(3))),
        AccumulationMode.ACCUMULATING,
        zip(range(10), 'abcdefghij'),
        {IntervalWindow(0, 11): [set('ab')],
         IntervalWindow(0, 15): [set('abcdef')],
         IntervalWindow(0, 17): [set('abcdefgh')]},
        2)

  def test_picklable_output(self):
    global_window = (trigger.GlobalWindow(), )
    driver = trigger.BatchGlobalTriggerDriver()
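    # A generator expression is not picklable, which the assertion below
    # verifies; the driver must nevertheless produce picklable output from it.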
    unpicklable = (WindowedValue(k, 0, global_window) for k in range(10))
    with self.assertRaises(TypeError):
      pickle.dumps(unpicklable)
    for unwindowed in driver.process_elements(None, unpicklable, None, None):
      self.assertEqual(
          pickle.loads(pickle.dumps(unwindowed)).value, list(range(10)))


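# The tests below exercise TriggerFn.may_lose_data, which classifies a trigger
# under a given Windowing as either DataLossReason.NO_POTENTIAL_LOSS or
# DataLossReason.MAY_FINISH: a trigger that may finish can drop any data that
# arrives after it has finished.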
class MayLoseDataTest(unittest.TestCase):
  def _test(self, trigger, lateness, expected):
    windowing = WindowInto(
        GlobalWindows(),
        trigger=trigger,
        accumulation_mode=AccumulationMode.ACCUMULATING,
        allowed_lateness=lateness).windowing
    self.assertEqual(trigger.may_lose_data(windowing), expected)

  def test_default_trigger(self):
    self._test(DefaultTrigger(), 0, DataLossReason.NO_POTENTIAL_LOSS)

  def test_after_processing(self):
    self._test(AfterProcessingTime(42), 0, DataLossReason.MAY_FINISH)

  def test_always(self):
    self._test(Always(), 0, DataLossReason.NO_POTENTIAL_LOSS)

  def test_never(self):
    self._test(_Never(), 0, DataLossReason.NO_POTENTIAL_LOSS)

  def test_after_watermark_no_allowed_lateness(self):
    self._test(AfterWatermark(), 0, DataLossReason.NO_POTENTIAL_LOSS)

  def test_after_watermark_no_late_trigger(self):
    self._test(AfterWatermark(), 60, DataLossReason.MAY_FINISH)

  def test_after_watermark_no_allowed_lateness_safe_late(self):
    self._test(
        AfterWatermark(late=DefaultTrigger()),
        0,
        DataLossReason.NO_POTENTIAL_LOSS)

  def test_after_watermark_allowed_lateness_safe_late(self):
    self._test(
        AfterWatermark(late=DefaultTrigger()),
        60,
        DataLossReason.NO_POTENTIAL_LOSS)

  def test_after_count(self):
    self._test(AfterCount(42), 0, DataLossReason.MAY_FINISH)

  def test_repeatedly_safe_underlying(self):
    self._test(
        Repeatedly(DefaultTrigger()), 0, DataLossReason.NO_POTENTIAL_LOSS)

  def test_repeatedly_unsafe_underlying(self):
    self._test(Repeatedly(AfterCount(42)), 0, DataLossReason.NO_POTENTIAL_LOSS)

  def test_after_any_one_may_finish(self):
    self._test(
        AfterAny(AfterCount(42), DefaultTrigger()),
        0,
        DataLossReason.MAY_FINISH)

  def test_after_any_all_safe(self):
    self._test(
        AfterAny(Repeatedly(AfterCount(42)), DefaultTrigger()),
        0,
        DataLossReason.NO_POTENTIAL_LOSS)

  def test_after_all_some_may_finish(self):
    self._test(
        AfterAll(AfterCount(1), DefaultTrigger()),
        0,
        DataLossReason.NO_POTENTIAL_LOSS)

  def test_after_all_all_may_finish(self):
    self._test(
        AfterAll(AfterCount(42), AfterProcessingTime(42)),
        0,
        DataLossReason.MAY_FINISH)

  def test_after_each_at_least_one_safe(self):
    self._test(
        AfterEach(AfterCount(1), DefaultTrigger(), AfterCount(2)),
        0,
        DataLossReason.NO_POTENTIAL_LOSS)

  def test_after_each_all_may_finish(self):
    self._test(
        AfterEach(AfterCount(1), AfterCount(2), AfterCount(3)),
        0,
        DataLossReason.MAY_FINISH)


class RunnerApiTest(unittest.TestCase):
  def test_trigger_encoding(self):
    for trigger_fn in (DefaultTrigger(),
                       AfterAll(AfterCount(1), AfterCount(10)),
                       AfterAny(AfterCount(10), AfterCount(100)),
                       AfterWatermark(early=AfterCount(1000)),
                       AfterWatermark(early=AfterCount(1000),
                                      late=AfterCount(1)),
                       Repeatedly(AfterCount(100)),
                       trigger.OrFinally(AfterCount(3), AfterCount(10))):
      context = pipeline_context.PipelineContext()
      self.assertEqual(
          trigger_fn,
          TriggerFn.from_runner_api(trigger_fn.to_runner_api(context), context))


class TriggerPipelineTest(unittest.TestCase):
  def test_after_processing_time(self):
    test_options = PipelineOptions(
        flags=['--allow_unsafe_triggers', '--streaming'])
    with TestPipeline(options=test_options) as p:

      total_elements_in_trigger = 4
      processing_time_delay = 2
      window_size = 10

      # yapf: disable
      test_stream = TestStream()
      for i in range(total_elements_in_trigger):
        (test_stream
         .advance_processing_time(
            processing_time_delay / total_elements_in_trigger)
         .add_elements([('key', i)])
         )

      test_stream.advance_processing_time(processing_time_delay)

      # Add dropped elements
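      # (these arrive after AfterProcessingTime(processing_time_delay) has
      # already fired; a one-shot trigger that has fired is finished, so no
      # further pane is expected to contain them).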
      (test_stream
         .advance_processing_time(0.1)
         .add_elements([('key', "dropped-1")])
         .advance_processing_time(0.1)
         .add_elements([('key', "dropped-2")])
      )

      (test_stream
       .advance_processing_time(processing_time_delay)
       .advance_watermark_to_infinity()
       )
      # yapf: enable

      results = (
          p
          | test_stream
          | beam.WindowInto(
              FixedWindows(window_size),
              trigger=AfterProcessingTime(processing_time_delay),
              accumulation_mode=AccumulationMode.DISCARDING)
          | beam.GroupByKey()
          | beam.Map(lambda x: x[1]))

      assert_that(results, equal_to([list(range(total_elements_in_trigger))]))

  def test_repeatedly_after_processing_time(self):
    test_options = PipelineOptions(flags=['--streaming'])
    with TestPipeline(options=test_options) as p:
      total_elements = 7
      processing_time_delay = 2
      window_size = 10
      # yapf: disable
      test_stream = TestStream()
      for i in range(total_elements):
        (test_stream
         .advance_processing_time(processing_time_delay - 0.01)
         .add_elements([('key', i)])
         )

      (test_stream
       .advance_processing_time(processing_time_delay)
       .advance_watermark_to_infinity()
       )
      # yapf: enable

      results = (
          p
          | test_stream
          | beam.WindowInto(
              FixedWindows(window_size),
              trigger=Repeatedly(AfterProcessingTime(processing_time_delay)),
              accumulation_mode=AccumulationMode.DISCARDING)
          | beam.GroupByKey()
          | beam.Map(lambda x: x[1]))

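      # Elements arrive every (processing_time_delay - 0.01) while the trigger
      # re-fires every processing_time_delay, so each firing should see two
      # new elements, except possibly a final odd element in its own pane.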
      expected = [[i, i + 1]
                  for i in range(total_elements - total_elements % 2)
                  if i % 2 == 0]
      expected += [] if total_elements % 2 == 0 else [[total_elements - 1]]

      assert_that(results, equal_to(expected))

  def test_after_count(self):
    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
    with TestPipeline(options=test_options) as p:

      def construct_timestamped(k, t):
        return TimestampedValue((k, t), t)

      def format_result(k, vs):
        return ('%s-%s' % (k, len(list(vs))), set(vs))

      result = (
          p
          | beam.Create([1, 2, 3, 4, 5, 10, 11])
          | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
          | beam.MapTuple(construct_timestamped)
          | beam.WindowInto(
              FixedWindows(10),
              trigger=AfterCount(3),
              accumulation_mode=AccumulationMode.DISCARDING)
          | beam.GroupByKey()
          | beam.MapTuple(format_result))
      assert_that(
          result,
          equal_to(
              list({
                  'A-5': {1, 2, 3, 4, 5},
                  # A-10, A-11 never emitted due to AfterCount(3) never firing.
                  'B-4': {6, 7, 8, 9},
                  'B-3': {10, 15, 16},
              }.items())))

  def test_after_count_streaming(self):
    test_options = PipelineOptions(
        flags=['--allow_unsafe_triggers', '--streaming'])
    with TestPipeline(options=test_options) as p:
      # yapf: disable
      test_stream = (
          TestStream()
          .advance_watermark_to(0)
          .add_elements([('A', 1), ('A', 2), ('A', 3)])
          .add_elements([('A', 4), ('A', 5), ('A', 6)])
          .add_elements([('B', 1), ('B', 2), ('B', 3)])
          .advance_watermark_to_infinity())
      # yapf: enable

      results = (
          p
          | test_stream
          | beam.WindowInto(
              FixedWindows(10),
              trigger=AfterCount(3),
              accumulation_mode=AccumulationMode.ACCUMULATING)
          | beam.GroupByKey())

      assert_that(
          results,
          equal_to(list({
            'A': [1, 2, 3], # 4 - 6 discarded because trigger finished
            'B': [1, 2, 3]}.items())))

  def test_always(self):
    with TestPipeline() as p:

      def construct_timestamped(k, t):
        return TimestampedValue((k, t), t)

      def format_result(k, vs):
        return ('%s-%s' % (k, len(list(vs))), set(vs))

      result = (
          p
          | beam.Create([1, 1, 2, 3, 4, 5, 10, 11])
          | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
          | beam.MapTuple(construct_timestamped)
          | beam.WindowInto(
              FixedWindows(10),
              trigger=Always(),
              accumulation_mode=AccumulationMode.DISCARDING)
          | beam.GroupByKey()
          | beam.MapTuple(format_result))
      assert_that(
          result,
          equal_to(
              list({
                  'A-2': {10, 11},
                  # Elements out of windows are also emitted.
                  'A-6': {1, 2, 3, 4, 5},
                  # A,1 is emitted twice.
                  'B-5': {6, 7, 8, 9},
                  # B,6 is emitted twice.
                  'B-3': {10, 15, 16},
              }.items())))

  def test_never(self):
    with TestPipeline() as p:

      def construct_timestamped(k, t):
        return TimestampedValue((k, t), t)

      def format_result(k, vs):
        return ('%s-%s' % (k, len(list(vs))), set(vs))

      result = (
          p
          | beam.Create([1, 1, 2, 3, 4, 5, 10, 11])
          | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
          | beam.MapTuple(construct_timestamped)
          | beam.WindowInto(
              FixedWindows(10),
              trigger=_Never(),
              accumulation_mode=AccumulationMode.DISCARDING)
          | beam.GroupByKey()
          | beam.MapTuple(format_result))
      assert_that(
          result,
          equal_to(
              list({
                  'A-2': {10, 11},
                  'A-6': {1, 2, 3, 4, 5},
                  'B-5': {6, 7, 8, 9},
                  'B-3': {10, 15, 16},
              }.items())))

  def test_multiple_accumulating_firings(self):
    # PCollection will contain elements from 1 to 10.
    elements = [i for i in range(1, 11)]

    ts = TestStream().advance_watermark_to(0)
    for i in elements:
      ts.add_elements([('key', str(i))])
      if i % 5 == 0:
        ts.advance_watermark_to(i)
        ts.advance_processing_time(5)
    ts.advance_watermark_to_infinity()

    options = PipelineOptions()
    options.view_as(StandardOptions).streaming = True
    with TestPipeline(options=options) as p:
      records = (
          p
          | ts
          | beam.WindowInto(
              FixedWindows(10),
              accumulation_mode=trigger.AccumulationMode.ACCUMULATING,
              trigger=AfterWatermark(
                  early=AfterAll(AfterCount(1), AfterProcessingTime(5))))
          | beam.GroupByKey()
          | beam.FlatMap(lambda x: x[1]))

    # The trigger should fire twice. Once after 5 seconds, and once after 10.
    # The firings should accumulate the output.
    first_firing = [str(i) for i in elements if i <= 5]
    second_firing = [str(i) for i in elements]
    assert_that(records, equal_to(first_firing + second_firing))

  def test_on_pane_watermark_hold_no_pipeline_stall(self):
    """A regression test added for
    https://issues.apache.org/jira/browse/BEAM-10054."""
    START_TIMESTAMP = 1534842000

    test_stream = TestStream()
    test_stream.add_elements(['a'])
    test_stream.advance_processing_time(START_TIMESTAMP + 1)
    test_stream.advance_watermark_to(START_TIMESTAMP + 1)
    test_stream.add_elements(['b'])
    test_stream.advance_processing_time(START_TIMESTAMP + 2)
    test_stream.advance_watermark_to(START_TIMESTAMP + 2)

    with TestPipeline(options=PipelineOptions(
        ['--streaming', '--allow_unsafe_triggers'])) as p:
      # pylint: disable=expression-not-assigned
      (
          p
          | 'TestStream' >> test_stream
          | 'timestamp' >>
          beam.Map(lambda x: beam.window.TimestampedValue(x, START_TIMESTAMP))
          | 'kv' >> beam.Map(lambda x: (x, x))
          | 'window_1m' >> beam.WindowInto(
              beam.window.FixedWindows(60),
              trigger=trigger.AfterAny(
                  trigger.AfterProcessingTime(3600), trigger.AfterWatermark()),
              accumulation_mode=trigger.AccumulationMode.DISCARDING)
          | 'group_by_key' >> beam.GroupByKey()
          | 'filter' >> beam.Map(lambda x: x))


class TranscriptTest(unittest.TestCase):

  # We must prepend an underscore to this name so that the open-source unittest
  # runner does not execute this method directly as a test.
  @classmethod
  def _create_test(cls, spec):
    counter = 0
    name = spec.get('name', 'unnamed')
    unique_name = 'test_' + name
    while hasattr(cls, unique_name):
      counter += 1
      unique_name = 'test_%s_%d' % (name, counter)
    test_method = lambda self: self._run_log_test(spec)
    test_method.__name__ = unique_name
    test_method.__test__ = True
    setattr(cls, unique_name, test_method)

  # We must prepend an underscore to this name so that the open-source unittest
  # runner does not execute this method directly as a test.
  @classmethod
  def _create_tests(cls, transcript_filename):
    with open(transcript_filename) as transcript_file:
      for spec in yaml.load_all(transcript_file, Loader=yaml.SafeLoader):
        cls._create_test(spec)

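  # For reference, each transcript spec parsed below is a YAML document. A
  # hedged, illustrative sketch of the expected shape (the real specs live in
  # testing/data/trigger_transcripts.yaml):
  #
  #   name: example_early_late
  #   window_fn: Sessions(10)
  #   trigger_fn: AfterWatermark(early=AfterCount(2), late=AfterCount(1))
  #   transcript:
  #     - input: [1, 2, 3]
  #     - watermark: 25
  #     - expect:
  #         - {window: [1, 13], values: [1, 2, 3], late: false}
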
  def _run_log_test(self, spec):
    if 'error' in spec:
      self.assertRaisesRegex(Exception, spec['error'], self._run_log, spec)
    else:
      self._run_log(spec)

  def _run_log(self, spec):
    def parse_int_list(s):
      """Parses strings like '[1, 2, 3]'."""
      s = s.strip()
      assert s[0] == '[' and s[-1] == ']', s
      if not s[1:-1].strip():
        return []
      return [int(x) for x in s[1:-1].split(',')]

    def split_args(s):
      """Splits 'a, b, [c, d]' into ['a', 'b', '[c, d]']."""
      args = []
      start = 0
      depth = 0
      for ix in range(len(s)):
        c = s[ix]
        if c in '({[':
          depth += 1
        elif c in ')}]':
          depth -= 1
        elif c == ',' and depth == 0:
          args.append(s[start:ix].strip())
          start = ix + 1
      assert depth == 0, s
      args.append(s[start:].strip())
      return args

    def parse(s, names):
      """Parse (recursive) 'Foo(arg, kw=arg)' for Foo in the names dict."""
      s = s.strip()
      if s in names:
        return names[s]
      elif s[0] == '[':
        return parse_int_list(s)
      elif '(' in s:
        assert s[-1] == ')', s
        callee = parse(s[:s.index('(')], names)
        posargs = []
        kwargs = {}
        for arg in split_args(s[s.index('(') + 1:-1]):
          if '=' in arg:
            kw, value = arg.split('=', 1)
            kwargs[kw] = parse(value, names)
          else:
            posargs.append(parse(arg, names))
        return callee(*posargs, **kwargs)
      else:
        try:
          return int(s)
        except ValueError:
          raise ValueError('Unknown function: %s' % s)

    def parse_fn(s, names):
      """Like parse(), but implicitly calls no-arg constructors."""
      fn = parse(s, names)
      if isinstance(fn, type):
        return fn()
      return fn

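    # For example (illustrative): parse_fn('AfterCount(2)', trigger_names)
    # builds AfterCount(2), while parse_fn('Default', trigger_names) calls the
    # DefaultTrigger class to produce an instance.
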
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.transforms import window as window_module
    # pylint: enable=wrong-import-order, wrong-import-position
    window_fn_names = dict(window_module.__dict__)
    # yapf: disable
    window_fn_names.update({
        'CustomTimestampingFixedWindowsWindowFn':
            CustomTimestampingFixedWindowsWindowFn
    })
    # yapf: enable
    trigger_names = {'Default': DefaultTrigger}
    trigger_names.update(trigger.__dict__)

    window_fn = parse_fn(
        spec.get('window_fn', 'GlobalWindows'), window_fn_names)
    trigger_fn = parse_fn(spec.get('trigger_fn', 'Default'), trigger_names)
    accumulation_mode = getattr(
        AccumulationMode, spec.get('accumulation_mode', 'ACCUMULATING').upper())
    timestamp_combiner = getattr(
        TimestampCombiner,
        spec.get('timestamp_combiner', 'OUTPUT_AT_EOW').upper())
    allowed_lateness = spec.get('allowed_lateness', 0.000)

    def only_element(xs):
      x, = list(xs)
      return x

    transcript = [only_element(line.items()) for line in spec['transcript']]

    self._execute(
        window_fn,
        trigger_fn,
        accumulation_mode,
        timestamp_combiner,
        allowed_lateness,
        transcript,
        spec)


def _windowed_value_info(windowed_value):
  # Currently some runners operate at the millisecond level, and some at the
  # microsecond level.  Trigger transcript timestamps are expressed as
  # integral units of the finest granularity, whatever that may be.
  # In these tests we interpret them as integral seconds and then truncate
  # the results to integral seconds to allow for portability across
  # different sub-second resolutions.
  window, = windowed_value.windows
  return {
      'window': [int(window.start), int(window.max_timestamp())],
      'values': sorted(windowed_value.value),
      'timestamp': int(windowed_value.timestamp),
      'index': windowed_value.pane_info.index,
      'nonspeculative_index': windowed_value.pane_info.nonspeculative_index,
      'early': windowed_value.pane_info.timing == PaneInfoTiming.EARLY,
      'late': windowed_value.pane_info.timing == PaneInfoTiming.LATE,
      'final': windowed_value.pane_info.is_last,
  }

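# For instance (illustrative): an on-time first pane for IntervalWindow(0, 10)
# holding [1, 2] at timestamp 9 would map to {'window': [0, 9],
# 'values': [1, 2], 'timestamp': 9, 'index': 0, 'nonspeculative_index': 0,
# 'early': False, 'late': False, 'final': False}.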

def _windowed_value_info_map_fn(
    k,
    vs,
    window=beam.DoFn.WindowParam,
    t=beam.DoFn.TimestampParam,
    p=beam.DoFn.PaneInfoParam):
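  """Maps a grouped (key, values) pair, plus the window/timestamp/pane-info
  DoFn parameters, to (key, pane_info_dict) via _windowed_value_info."""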
  return (
      k,
      _windowed_value_info(
          WindowedValue(vs, windows=[window], timestamp=t, pane_info=p)))


def _windowed_value_info_check(actual, expected, key=None):
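  """Checks that the actual panes match the expected panes, comparing only
  the fields present in each expected pane; raises AssertionError otherwise."""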

  key_string = ' for %s' % key if key else ''

  def format(panes):
    return '\n[%s]\n' % '\n '.join(
        str(pane)
        for pane in sorted(panes, key=lambda pane: pane.get('timestamp', None)))

  if len(actual) > len(expected):
    raise AssertionError(
        'Unexpected output%s: expected %s but got %s' %
        (key_string, format(expected), format(actual)))
  elif len(expected) > len(actual):
    raise AssertionError(
        'Unmatched output%s: expected %s but got %s' %
        (key_string, format(expected), format(actual)))
  else:

    def diff(actual, expected):
      for key in sorted(expected.keys(), reverse=True):
        if key in actual:
          if actual[key] != expected[key]:
            return key

    for output in actual:
      diffs = [diff(output, pane) for pane in expected]
      if all(diffs):
        raise AssertionError(
            'Unmatched output%s: %s not found in %s (diffs in %s)' %
            (key_string, output, format(expected), diffs))


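# Concatenates all per-key inputs into a single list, so that
# CombinePerKey(_ConcatCombineFn()) yields the same per-pane contents as a
# GroupByKey (up to ordering) and both aggregations can share transcripts.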
class _ConcatCombineFn(beam.CombineFn):
  create_accumulator = lambda self: []  # type: ignore[var-annotated]
  add_input = lambda self, acc, element: acc.append(element) or acc
  merge_accumulators = lambda self, accs: sum(accs, [])  # type: ignore[var-annotated]
  extract_output = lambda self, acc: acc


class TriggerDriverTranscriptTest(TranscriptTest):
  def _execute(
      self,
      window_fn,
      trigger_fn,
      accumulation_mode,
      timestamp_combiner,
      allowed_lateness,
      transcript,
      unused_spec):

    driver = GeneralTriggerDriver(
        Windowing(
            window_fn,
            trigger_fn,
            accumulation_mode,
            timestamp_combiner,
            allowed_lateness),
        TestClock())
    state = InMemoryUnmergedState()
    output = []
    watermark = MIN_TIMESTAMP

    def fire_timers():
      to_fire = state.get_and_clear_timers(watermark)
      while to_fire:
        for timer_window, (name, time_domain, t_timestamp, _) in to_fire:
          for wvalue in driver.process_timer(timer_window,
                                             name,
                                             time_domain,
                                             t_timestamp,
                                             state):
            output.append(_windowed_value_info(wvalue))
        to_fire = state.get_and_clear_timers(watermark)

    for action, params in transcript:

      if action != 'expect':
        # Fail if we have output that was not expected in the transcript.
        self.assertEqual([],
                         output,
                         msg='Unexpected output: %s before %s: %s' %
                         (output, action, params))

      if action == 'input':
        bundle = [
            WindowedValue(t, t, window_fn.assign(WindowFn.AssignContext(t, t)))
            for t in params
        ]
        output = [
            _windowed_value_info(wv) for wv in driver.process_elements(
                state, bundle, watermark, watermark)
        ]
        fire_timers()

      elif action == 'watermark':
        watermark = params
        fire_timers()

      elif action == 'expect':
        for expected_output in params:
          for candidate in output:
            if all(candidate[k] == expected_output[k] for k in candidate
                   if k in expected_output):
              output.remove(candidate)
              break
          else:
            self.fail('Unmatched output %s in %s' % (expected_output, output))

      elif action == 'state':
        # TODO(robertwb): Implement once we support allowed lateness.
        pass

      else:
        self.fail('Unknown action: ' + action)

    # Fail if we have output that was not expected in the transcript.
    self.assertEqual([], output, msg='Unexpected output: %s' % output)


class BaseTestStreamTranscriptTest(TranscriptTest):
  """A suite of TestStream-based tests based on trigger transcript entries.
  """
  def _execute(
      self,
      window_fn,
      trigger_fn,
      accumulation_mode,
      timestamp_combiner,
      allowed_lateness,
      transcript,
      spec):

    runner_name = TestPipeline().runner.__class__.__name__
    if runner_name in spec.get('broken_on', ()):
      self.skipTest('Known to be broken on %s' % runner_name)

    is_order_agnostic = (
        isinstance(trigger_fn, DefaultTrigger) and
        accumulation_mode == AccumulationMode.ACCUMULATING)

    if is_order_agnostic:
      reshuffle_seed = random.randrange(1 << 20)
      keys = [
          u'original',
          u'reversed',
          u'reshuffled(%s)' % reshuffle_seed,
          u'one-element-bundles',
          u'one-element-bundles-reversed',
          u'two-element-bundles'
      ]
    else:
      keys = [u'key1', u'key2']

    # Elements are encoded as JSON strings to allow other languages to
    # decode them while executing the test stream.
    # TODO(https://github.com/apache/beam/issues/19934): Eliminate these
    # gymnastics.
    test_stream = TestStream(coder=coders.StrUtf8Coder()).with_output_types(str)
    for action, params in transcript:
      if action == 'expect':
        test_stream.add_elements([json.dumps(('expect', params))])
      else:
        test_stream.add_elements([json.dumps(('expect', []))])
        if action == 'input':

          def keyed(key, values):
            return [json.dumps(('input', (key, v))) for v in values]

          if is_order_agnostic:
            # Must match keys above.
            test_stream.add_elements(keyed('original', params))
            test_stream.add_elements(keyed('reversed', reversed(params)))
            r = random.Random(reshuffle_seed)
            reshuffled = list(params)
            r.shuffle(reshuffled)
            test_stream.add_elements(
                keyed('reshuffled(%s)' % reshuffle_seed, reshuffled))
            for v in params:
              test_stream.add_elements(keyed('one-element-bundles', [v]))
            for v in reversed(params):
              test_stream.add_elements(
                  keyed('one-element-bundles-reversed', [v]))
            for ix in range(0, len(params), 2):
              test_stream.add_elements(
                  keyed('two-element-bundles', params[ix:ix + 2]))
          else:
            for key in keys:
              test_stream.add_elements(keyed(key, params))
        elif action == 'watermark':
          test_stream.advance_watermark_to(params)
        elif action == 'clock':
          test_stream.advance_processing_time(params)
        elif action == 'state':
          pass  # Requires inspection of implementation details.
        else:
          raise ValueError('Unexpected action: %s' % action)
    test_stream.add_elements([json.dumps(('expect', []))])
    test_stream.advance_watermark_to_infinity()

    read_test_stream = test_stream | beam.Map(json.loads)

    class Check(beam.DoFn):
      """A StatefulDoFn that verifies outputs are produced as expected.

      This DoFn takes in two kinds of inputs, actual outputs and
      expected outputs.  When an actual output is received, it is buffered
      into state, and when an expected output is received, this buffered
      state is retrieved and compared against the expected value(s) to ensure
      they match.

      The key is ignored, but all items must be on the same key to share state.
      """
      def __init__(self, allow_out_of_order=True):
        # Some runners don't support cross-stage TestStream semantics.
        self.allow_out_of_order = allow_out_of_order

      def process(
          self,
          element,
          seen=beam.DoFn.StateParam(
              beam.transforms.userstate.BagStateSpec(
                  'seen', beam.coders.FastPrimitivesCoder())),
          expected=beam.DoFn.StateParam(
              beam.transforms.userstate.BagStateSpec(
                  'expected', beam.coders.FastPrimitivesCoder()))):
        key, (action, data) = element

        if self.allow_out_of_order:
          if action == 'expect' and not list(seen.read()):
            if data:
              expected.add(data)
            return
          elif action == 'actual' and list(expected.read()):
            seen.add(data)
            all_data = list(seen.read())
            all_expected = list(expected.read())
            if len(all_data) == len(all_expected[0]):
              expected.clear()
              for expect in all_expected[1:]:
                expected.add(expect)
              action, data = 'expect', all_expected[0]
            else:
              return

        if action == 'actual':
          seen.add(data)

        elif action == 'expect':
          actual = list(seen.read())
          seen.clear()
          _windowed_value_info_check(actual, data, key)

        else:
          raise ValueError('Unexpected action: %s' % action)

    @ptransform.ptransform_fn
    def CheckAggregation(inputs_and_expected, aggregation):
      # Split the test stream into a branch of to-be-processed elements, and
      # a branch of expected results.
      inputs, expected = (
          inputs_and_expected
          | beam.MapTuple(
              lambda tag, value: beam.pvalue.TaggedOutput(tag, value),
              ).with_outputs('input', 'expect'))

      # Process the inputs with the given windowing to produce actual outputs.
      outputs = (
          inputs
          | beam.MapTuple(
              lambda key, value: TimestampedValue((key, value), value))
          | beam.WindowInto(
              window_fn,
              trigger=trigger_fn,
              accumulation_mode=accumulation_mode,
              timestamp_combiner=timestamp_combiner,
              allowed_lateness=allowed_lateness)
          | aggregation
          | beam.MapTuple(_windowed_value_info_map_fn)
          # Place outputs back into the global window to allow flattening
          # and share a single state in Check.
          | 'Global' >> beam.WindowInto(beam.transforms.window.GlobalWindows()))
      # Feed both the expected and actual outputs to Check() for comparison.
      tagged_expected = (
          expected | beam.FlatMap(
              lambda value: [(key, ('expect', value)) for key in keys]))
      tagged_outputs = (
          outputs | beam.MapTuple(lambda key, value: (key, ('actual', value))))
      # pylint: disable=expression-not-assigned
      ([tagged_expected, tagged_outputs]
       | beam.Flatten()
       | beam.ParDo(Check(self.allow_out_of_order)))

    with TestPipeline() as p:
      # TODO(https://github.com/apache/beam/issues/19933): Pass this during
      # pipeline construction.
      p._options.view_as(StandardOptions).streaming = True
      p._options.view_as(TypeOptions).allow_unsafe_triggers = True

      # We can have at most one test stream per pipeline, so we share it.
      inputs_and_expected = p | read_test_stream
      _ = inputs_and_expected | CheckAggregation(beam.GroupByKey())
      _ = inputs_and_expected | CheckAggregation(
          beam.CombinePerKey(_ConcatCombineFn()))


class TestStreamTranscriptTest(BaseTestStreamTranscriptTest):
  allow_out_of_order = False


class WeakTestStreamTranscriptTest(BaseTestStreamTranscriptTest):
  allow_out_of_order = True


class BatchTranscriptTest(TranscriptTest):
  def _execute(
      self,
      window_fn,
      trigger_fn,
      accumulation_mode,
      timestamp_combiner,
      allowed_lateness,
      transcript,
      spec):
    if timestamp_combiner == TimestampCombiner.OUTPUT_AT_EARLIEST_TRANSFORMED:
      self.skipTest(
          'Non-fnapi timestamp combiner: %s' % spec.get('timestamp_combiner'))

    if accumulation_mode != AccumulationMode.ACCUMULATING:
      self.skipTest('Batch mode only makes sense for accumulating.')

    watermark = MIN_TIMESTAMP
    for action, params in transcript:
      if action == 'watermark':
        watermark = params
      elif action == 'input':
        if any(t <= watermark for t in params):
          self.skipTest('Batch mode never has late data.')

    inputs = sum([vs for action, vs in transcript if action == 'input'], [])
    final_panes_by_window = {}
    for action, params in transcript:
      if action == 'expect':
        for expected in params:
          trimmed = {}
          for field in ('window', 'values', 'timestamp'):
            if field in expected:
              trimmed[field] = expected[field]
          final_panes_by_window[tuple(expected['window'])] = trimmed
    final_panes = list(final_panes_by_window.values())

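    # For merging window functions, expected panes whose windows get merged
    # away into a larger result window can never appear in the batch output,
    # so drop them before comparing.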
    if window_fn.is_merging():
      merged_away = set()

      class MergeContext(WindowFn.MergeContext):
        def merge(_, to_be_merged, merge_result):
          for window in to_be_merged:
            if window != merge_result:
              merged_away.add(window)

      all_windows = [IntervalWindow(*pane['window']) for pane in final_panes]
      window_fn.merge(MergeContext(all_windows))
      final_panes = [
          pane for pane in final_panes
          if IntervalWindow(*pane['window']) not in merged_away
      ]

    with TestPipeline() as p:
      input_pc = (
          p
          | beam.Create(inputs)
          | beam.Map(lambda t: TimestampedValue(('key', t), t))
          | beam.WindowInto(
              window_fn,
              trigger=trigger_fn,
              accumulation_mode=accumulation_mode,
              timestamp_combiner=timestamp_combiner,
              allowed_lateness=allowed_lateness))

      grouped = input_pc | 'Grouped' >> (
          beam.GroupByKey()
          | beam.MapTuple(_windowed_value_info_map_fn)
          | beam.MapTuple(lambda _, value: value))

      combined = input_pc | 'Combined' >> (
          beam.CombinePerKey(_ConcatCombineFn())
          | beam.MapTuple(_windowed_value_info_map_fn)
          | beam.MapTuple(lambda _, value: value))

      assert_that(
          grouped,
          lambda actual: _windowed_value_info_check(actual, final_panes),
          label='CheckGrouped')

      assert_that(
          combined,
          lambda actual: _windowed_value_info_check(actual, final_panes),
          label='CheckCombined')


TRANSCRIPT_TEST_FILE = os.path.join(
    os.path.dirname(__file__),
    '..',
    'testing',
    'data',
    'trigger_transcripts.yaml')
if os.path.exists(TRANSCRIPT_TEST_FILE):
  TriggerDriverTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
  TestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
  WeakTestStreamTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)
  BatchTranscriptTest._create_tests(TRANSCRIPT_TEST_FILE)

if __name__ == '__main__':
  unittest.main()